Implementing gradients for tf.image.resize_bicubic.
PiperOrigin-RevId: 168547412
This commit is contained in:
parent
4982ef0fa4
commit
3331c574bc
@ -134,6 +134,7 @@ op { name: "WholeFileReaderV2" rename_to: "WholeFileReader" }
|
|||||||
# image_ops
|
# image_ops
|
||||||
op { name: "AdjustContrastv2" rename_to: "AdjustContrast" }
|
op { name: "AdjustContrastv2" rename_to: "AdjustContrast" }
|
||||||
op { name: "ResizeBilinearGrad" hide: true }
|
op { name: "ResizeBilinearGrad" hide: true }
|
||||||
|
op { name: "ResizeBicubicGrad" hide: true }
|
||||||
op { name: "ResizeNearestNeighborGrad" hide: true }
|
op { name: "ResizeNearestNeighborGrad" hide: true }
|
||||||
|
|
||||||
# io_ops
|
# io_ops
|
||||||
|
@ -180,6 +180,21 @@ static void ComputeXWeightsAndIndices(const ImageResizerState& resizer_state,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ComputeGradientXWeightsAndIndices(
|
||||||
|
const ImageResizerGradientState& resizer_state,
|
||||||
|
std::vector<WeightsAndIndices>* x_wais) {
|
||||||
|
CachedInterpolationCalculator calc;
|
||||||
|
for (int64 x = 0; x < resizer_state.resized_width; ++x) {
|
||||||
|
GetWeightsAndIndices(resizer_state.width_scale, x,
|
||||||
|
resizer_state.original_width, &(*x_wais)[x]);
|
||||||
|
auto& x_wai = (*x_wais)[x];
|
||||||
|
x_wai.advance = calc.Advance(x_wai.index_0, x_wai.index_1, x_wai.index_2,
|
||||||
|
x_wai.index_3);
|
||||||
|
}
|
||||||
|
// Do not scale, as we will be using these directly as tensor indices on the
|
||||||
|
// gradient pass.
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static EIGEN_ALWAYS_INLINE float ComputeYInterpolation(
|
static EIGEN_ALWAYS_INLINE float ComputeYInterpolation(
|
||||||
int which, int channel_num, const WeightsAndIndices& y_wai,
|
int which, int channel_num, const WeightsAndIndices& y_wai,
|
||||||
@ -365,6 +380,73 @@ inline void interpolate_with_caching(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline void ResizeBicubicGrad(typename TTypes<float, 4>::ConstTensor input_grad,
|
||||||
|
const ImageResizerGradientState& resizer_state,
|
||||||
|
typename TTypes<T, 4>::Tensor output_grad) {
|
||||||
|
// This function computes gradients for the ResizeBicubic op by iterating over
|
||||||
|
// the input_grad Tensor and using WeightsAndIndices to appropriately update
|
||||||
|
// the output gradient.
|
||||||
|
const float height_scale = resizer_state.height_scale;
|
||||||
|
const int64 original_height = resizer_state.original_height;
|
||||||
|
const int channels = resizer_state.channels;
|
||||||
|
const int64 resized_width = resizer_state.resized_width;
|
||||||
|
const int64 resized_height = resizer_state.resized_height;
|
||||||
|
|
||||||
|
output_grad.setZero();
|
||||||
|
|
||||||
|
std::vector<WeightsAndIndices> x_wais(resizer_state.resized_width);
|
||||||
|
ComputeGradientXWeightsAndIndices(resizer_state, &x_wais);
|
||||||
|
for (int64 b = 0; b < resizer_state.batch_size; ++b) {
|
||||||
|
for (int64 y = 0; y < resized_height; ++y) {
|
||||||
|
WeightsAndIndices y_wai;
|
||||||
|
GetWeightsAndIndices(height_scale, y, original_height, &y_wai);
|
||||||
|
for (int64 x = 0; x < resized_width; ++x) {
|
||||||
|
const WeightsAndIndices& x_wai = x_wais[x];
|
||||||
|
for (int64 c = 0; c < channels; ++c) {
|
||||||
|
T curr_input_grad = input_grad(b, y, x, c);
|
||||||
|
// row 0 of 0, 1, 2, 3
|
||||||
|
output_grad(b, y_wai.index_0, x_wai.index_0, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_0 * x_wai.weight_0);
|
||||||
|
output_grad(b, y_wai.index_0, x_wai.index_1, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_0 * x_wai.weight_1);
|
||||||
|
output_grad(b, y_wai.index_0, x_wai.index_2, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_0 * x_wai.weight_2);
|
||||||
|
output_grad(b, y_wai.index_0, x_wai.index_3, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_0 * x_wai.weight_3);
|
||||||
|
// row 1 of 0, 1, 2, 3
|
||||||
|
output_grad(b, y_wai.index_1, x_wai.index_0, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_1 * x_wai.weight_0);
|
||||||
|
output_grad(b, y_wai.index_1, x_wai.index_1, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_1 * x_wai.weight_1);
|
||||||
|
output_grad(b, y_wai.index_1, x_wai.index_2, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_1 * x_wai.weight_2);
|
||||||
|
output_grad(b, y_wai.index_1, x_wai.index_3, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_1 * x_wai.weight_3);
|
||||||
|
// row 2 of 0, 1, 2, 3
|
||||||
|
output_grad(b, y_wai.index_2, x_wai.index_0, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_2 * x_wai.weight_0);
|
||||||
|
output_grad(b, y_wai.index_2, x_wai.index_1, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_2 * x_wai.weight_1);
|
||||||
|
output_grad(b, y_wai.index_2, x_wai.index_2, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_2 * x_wai.weight_2);
|
||||||
|
output_grad(b, y_wai.index_2, x_wai.index_3, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_2 * x_wai.weight_3);
|
||||||
|
// row 3 of 0, 1, 2, 3
|
||||||
|
output_grad(b, y_wai.index_3, x_wai.index_0, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_3 * x_wai.weight_0);
|
||||||
|
output_grad(b, y_wai.index_3, x_wai.index_1, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_3 * x_wai.weight_1);
|
||||||
|
output_grad(b, y_wai.index_3, x_wai.index_2, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_3 * x_wai.weight_2);
|
||||||
|
output_grad(b, y_wai.index_3, x_wai.index_3, c) +=
|
||||||
|
T(curr_input_grad * y_wai.weight_3 * x_wai.weight_3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
@ -394,6 +476,36 @@ class ResizeBicubicOp : public OpKernel {
|
|||||||
bool align_corners_;
|
bool align_corners_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class ResizeBicubicOpGrad : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit ResizeBicubicOpGrad(OpKernelConstruction* context)
|
||||||
|
: OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
// Validate input.
|
||||||
|
// First argument is gradient with respect to resized image.
|
||||||
|
const Tensor& input = context->input(0);
|
||||||
|
const Tensor& original_image = context->input(1);
|
||||||
|
|
||||||
|
ImageResizerGradientState st(align_corners_);
|
||||||
|
st.ValidateAndCreateOutput(context, input, original_image);
|
||||||
|
|
||||||
|
if (!context->status().ok()) return;
|
||||||
|
|
||||||
|
typename TTypes<float, 4>::ConstTensor input_grad =
|
||||||
|
input.tensor<float, 4>();
|
||||||
|
typename TTypes<T, 4>::Tensor output_grad = st.output->tensor<T, 4>();
|
||||||
|
|
||||||
|
ResizeBicubicGrad<T>(input_grad, st, output_grad);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool align_corners_;
|
||||||
|
};
|
||||||
|
|
||||||
#define REGISTER_KERNEL(T) \
|
#define REGISTER_KERNEL(T) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("ResizeBicubic") \
|
REGISTER_KERNEL_BUILDER(Name("ResizeBicubic") \
|
||||||
.Device(DEVICE_CPU) \
|
.Device(DEVICE_CPU) \
|
||||||
@ -405,4 +517,14 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
|
|||||||
|
|
||||||
#undef REGISTER_KERNEL
|
#undef REGISTER_KERNEL
|
||||||
|
|
||||||
|
#define REGISTER_GRAD_KERNEL(T) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("ResizeBicubicGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||||
|
ResizeBicubicOpGrad<CPUDevice, T>);
|
||||||
|
|
||||||
|
TF_CALL_float(REGISTER_GRAD_KERNEL);
|
||||||
|
TF_CALL_double(REGISTER_GRAD_KERNEL);
|
||||||
|
|
||||||
|
#undef REGISTER_GRAD_KERNEL
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -197,6 +197,31 @@ resized_images: 4-D with shape
|
|||||||
`[batch, new_height, new_width, channels]`.
|
`[batch, new_height, new_width, channels]`.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
REGISTER_OP("ResizeBicubicGrad")
|
||||||
|
.Input("grads: float")
|
||||||
|
.Input("original_image: T")
|
||||||
|
.Output("output: T")
|
||||||
|
.Attr("T: {float, double}")
|
||||||
|
.Attr("align_corners: bool = false")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
c->set_output(0, c->input(1));
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.Doc(R"doc(
|
||||||
|
Computes the gradient of bicubic interpolation.
|
||||||
|
|
||||||
|
grads: 4-D with shape `[batch, height, width, channels]`.
|
||||||
|
original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`,
|
||||||
|
The image tensor that was resized.
|
||||||
|
align_corners: If true, rescale grads by (orig_height - 1) / (height - 1), which
|
||||||
|
exactly aligns the 4 corners of grads and original_image. If false, rescale by
|
||||||
|
orig_height / height. Treat similarly the width dimension.
|
||||||
|
output: 4-D with shape `[batch, orig_height, orig_width, channels]`.
|
||||||
|
Gradients with respect to the input image. Input image must have been
|
||||||
|
float or double.
|
||||||
|
)doc");
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
REGISTER_OP("ResizeBilinear")
|
REGISTER_OP("ResizeBilinear")
|
||||||
.Input("images: T")
|
.Input("images: T")
|
||||||
|
@ -20176,6 +20176,43 @@ op {
|
|||||||
summary: "Resize `images` to `size` using bicubic interpolation."
|
summary: "Resize `images` to `size` using bicubic interpolation."
|
||||||
description: "Input images can be of different types but output images are always float."
|
description: "Input images can be of different types but output images are always float."
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "ResizeBicubicGrad"
|
||||||
|
input_arg {
|
||||||
|
name: "grads"
|
||||||
|
description: "4-D with shape `[batch, height, width, channels]`."
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "original_image"
|
||||||
|
description: "4-D with shape `[batch, orig_height, orig_width, channels]`,\nThe image tensor that was resized."
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output"
|
||||||
|
description: "4-D with shape `[batch, orig_height, orig_width, channels]`.\nGradients with respect to the input image. Input image must have been\nfloat or double."
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "align_corners"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: false
|
||||||
|
}
|
||||||
|
description: "If true, rescale grads by (orig_height - 1) / (height - 1), which\nexactly aligns the 4 corners of grads and original_image. If false, rescale by\norig_height / height. Treat similarly the width dimension."
|
||||||
|
}
|
||||||
|
summary: "Computes the gradient of bicubic interpolation."
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "ResizeBilinear"
|
name: "ResizeBilinear"
|
||||||
input_arg {
|
input_arg {
|
||||||
|
@ -167,6 +167,7 @@ NonMaxSuppression
|
|||||||
NonMaxSuppressionV2
|
NonMaxSuppressionV2
|
||||||
RandomCrop
|
RandomCrop
|
||||||
ResizeBilinearGrad
|
ResizeBilinearGrad
|
||||||
|
ResizeBicubicGrad
|
||||||
ResizeNearestNeighborGrad
|
ResizeNearestNeighborGrad
|
||||||
SampleDistortedBoundingBox
|
SampleDistortedBoundingBox
|
||||||
SampleDistortedBoundingBoxV2
|
SampleDistortedBoundingBoxV2
|
||||||
|
@ -73,6 +73,27 @@ def _ResizeBilinearGrad(op, grad):
|
|||||||
return [grad0, None]
|
return [grad0, None]
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("ResizeBicubic")
|
||||||
|
def _ResizeBicubicGrad(op, grad):
|
||||||
|
"""The derivatives for bicubic resizing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op: The ResizeBicubic op.
|
||||||
|
grad: The tensor representing the gradient w.r.t. the output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The gradients w.r.t. the input.
|
||||||
|
"""
|
||||||
|
allowed_types = [dtypes.float32, dtypes.float64]
|
||||||
|
grad0 = None
|
||||||
|
if op.inputs[0].dtype in allowed_types:
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
grad0 = gen_image_ops._resize_bicubic_grad(
|
||||||
|
grad, op.inputs[0], align_corners=op.get_attr("align_corners"))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
return [grad0, None]
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("CropAndResize")
|
@ops.RegisterGradient("CropAndResize")
|
||||||
def _CropAndResizeGrad(op, grad):
|
def _CropAndResizeGrad(op, grad):
|
||||||
"""The derivatives for crop_and_resize.
|
"""The derivatives for crop_and_resize.
|
||||||
|
@ -173,6 +173,67 @@ class ResizeBilinearOpTest(test.TestCase):
|
|||||||
self.assertAllClose(grad[False], grad[True], rtol=1e-4, atol=1e-4)
|
self.assertAllClose(grad[False], grad[True], rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeBicubicOpTest(test.TestCase):
|
||||||
|
|
||||||
|
def testShapeIsCorrectAfterOp(self):
|
||||||
|
in_shape = [1, 2, 2, 1]
|
||||||
|
out_shape = [1, 4, 6, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
|
||||||
|
|
||||||
|
for align_corners in [True, False]:
|
||||||
|
with self.test_session() as sess:
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
|
||||||
|
align_corners=align_corners)
|
||||||
|
self.assertEqual(out_shape, list(resize_out.get_shape()))
|
||||||
|
|
||||||
|
resize_out = sess.run(resize_out)
|
||||||
|
self.assertEqual(out_shape, list(resize_out.shape))
|
||||||
|
|
||||||
|
def testGradFromResizeToLargerInBothDims(self):
|
||||||
|
in_shape = [1, 2, 3, 1]
|
||||||
|
out_shape = [1, 4, 6, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
|
||||||
|
|
||||||
|
for align_corners in [True, False]:
|
||||||
|
with self.test_session():
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
|
||||||
|
align_corners=align_corners)
|
||||||
|
err = gradient_checker.compute_gradient_error(
|
||||||
|
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
|
||||||
|
self.assertLess(err, 1e-3)
|
||||||
|
|
||||||
|
def testGradFromResizeToSmallerInBothDims(self):
|
||||||
|
in_shape = [1, 4, 6, 1]
|
||||||
|
out_shape = [1, 2, 3, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
|
||||||
|
|
||||||
|
for align_corners in [True, False]:
|
||||||
|
with self.test_session():
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
|
||||||
|
align_corners=align_corners)
|
||||||
|
err = gradient_checker.compute_gradient_error(
|
||||||
|
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
|
||||||
|
self.assertLess(err, 1e-3)
|
||||||
|
|
||||||
|
def testGradOnUnsupportedType(self):
|
||||||
|
in_shape = [1, 4, 6, 1]
|
||||||
|
out_shape = [1, 2, 3, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3])
|
||||||
|
grad = gradients_impl.gradients(input_tensor, [resize_out])
|
||||||
|
self.assertEqual([None], grad)
|
||||||
|
|
||||||
|
|
||||||
class CropAndResizeOpTest(test.TestCase):
|
class CropAndResizeOpTest(test.TestCase):
|
||||||
|
|
||||||
def testShapeIsCorrectAfterOp(self):
|
def testShapeIsCorrectAfterOp(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user