Merge pull request #25467 from a6802739:crop_and_resize_gradient
PiperOrigin-RevId: 236007758
This commit is contained in:
commit
746397a4ed
@ -99,6 +99,25 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op,
|
|||||||
}
|
}
|
||||||
REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper);
|
REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper);
|
||||||
|
|
||||||
|
Status CropAndResizeGradHelper(const Scope& scope, const Operation& op,
|
||||||
|
const std::vector<Output>& grad_inputs,
|
||||||
|
std::vector<Output>* grad_outputs) {
|
||||||
|
DataType input_type;
|
||||||
|
string method;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "method", &method));
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "T", &input_type));
|
||||||
|
auto image_shape = Shape(scope, op.input(0));
|
||||||
|
grad_outputs->push_back(CropAndResizeGradImage(
|
||||||
|
scope, grad_inputs[0], op.input(1), op.input(2), image_shape, input_type,
|
||||||
|
CropAndResizeGradImage::Method(method)));
|
||||||
|
grad_outputs->push_back(CropAndResizeGradBoxes(
|
||||||
|
scope, grad_inputs[0], op.input(0), op.input(1), op.input(2)));
|
||||||
|
grad_outputs->push_back(NoGradient());
|
||||||
|
grad_outputs->push_back(NoGradient());
|
||||||
|
return scope.status();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRADIENT_OP("CropAndResize", CropAndResizeGradHelper);
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -27,6 +27,7 @@ namespace tensorflow {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ops::Const;
|
using ops::Const;
|
||||||
|
using ops::CropAndResize;
|
||||||
using ops::ResizeBicubic;
|
using ops::ResizeBicubic;
|
||||||
using ops::ResizeBilinear;
|
using ops::ResizeBilinear;
|
||||||
using ops::ResizeNearestNeighbor;
|
using ops::ResizeNearestNeighbor;
|
||||||
@ -219,5 +220,50 @@ class ScaleAndTranslateGradTest : public ::testing::Test {
|
|||||||
|
|
||||||
TEST_F(ScaleAndTranslateGradTest, Works) { TestResize<float, float, float>(); }
|
TEST_F(ScaleAndTranslateGradTest, Works) { TestResize<float, float, float>(); }
|
||||||
|
|
||||||
|
class CropAndResizeGradTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
CropAndResizeGradTest() : scope_(Scope::NewRootScope()) {}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Tensor MakeData(const TensorShape& data_shape) {
|
||||||
|
DataType data_type = DataTypeToEnum<T>::v();
|
||||||
|
Tensor data(data_type, data_shape);
|
||||||
|
auto data_flat = data.flat<T>();
|
||||||
|
for (int i = 0; i < data_flat.size(); ++i) {
|
||||||
|
data_flat(i) = T(i);
|
||||||
|
}
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind,
|
||||||
|
const Input& crop_szie, Output* x, Output* y) {
|
||||||
|
*x = Const<T>(scope_, x_data);
|
||||||
|
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_szie,
|
||||||
|
CropAndResize::Method("bilinear"));
|
||||||
|
TF_ASSERT_OK(scope_.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename X_T, typename Y_T, typename JAC_T>
|
||||||
|
void TestCropAndResize() {
|
||||||
|
TensorShape x_shape({1, 4, 2, 1});
|
||||||
|
Tensor x_data = MakeData<X_T>(x_shape);
|
||||||
|
TensorShape box_shape({1, 4});
|
||||||
|
Tensor boxes = MakeData<X_T>(box_shape);
|
||||||
|
Output x, y;
|
||||||
|
MakeOp<X_T>(x_data, boxes, {0}, {1, 1}, &x, &y);
|
||||||
|
JAC_T max_error;
|
||||||
|
TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
|
||||||
|
scope_, x, x_data, y, {1, 1, 1, 1}, &max_error)));
|
||||||
|
EXPECT_LT(max_error, 1e-3);
|
||||||
|
}
|
||||||
|
|
||||||
|
Scope scope_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(CropAndResizeGradTest, TestCrop) {
|
||||||
|
TestCropAndResize<float, float, float>();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
x
Reference in New Issue
Block a user