Add bilinear interpolation to tf.contrib.image.
Change: 155247916
This commit is contained in:
parent
fe16da297c
commit
37e3b71b49
@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>;
|
|||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
|
||||||
using functor::FillProjectiveTransform;
|
using functor::FillProjectiveTransform;
|
||||||
|
using generator::INTERPOLATION_BILINEAR;
|
||||||
|
using generator::INTERPOLATION_NEAREST;
|
||||||
|
using generator::Interpolation;
|
||||||
using generator::ProjectiveGenerator;
|
using generator::ProjectiveGenerator;
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class ImageProjectiveTransform : public OpKernel {
|
class ImageProjectiveTransform : public OpKernel {
|
||||||
|
private:
|
||||||
|
Interpolation interpolation_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit ImageProjectiveTransform(OpKernelConstruction* ctx)
|
explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
: OpKernel(ctx) {}
|
string interpolation_str;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
|
||||||
|
if (interpolation_str == "NEAREST") {
|
||||||
|
interpolation_ = INTERPOLATION_NEAREST;
|
||||||
|
} else if (interpolation_str == "BILINEAR") {
|
||||||
|
interpolation_ = INTERPOLATION_BILINEAR;
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "Invalid interpolation " << interpolation_str
|
||||||
|
<< ". Supported types: NEAREST, BILINEAR";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
const Tensor& images_t = ctx->input(0);
|
const Tensor& images_t = ctx->input(0);
|
||||||
@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel {
|
|||||||
Tensor* output_t;
|
Tensor* output_t;
|
||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
|
||||||
auto output = output_t->tensor<T, 4>();
|
auto output = output_t->tensor<T, 4>();
|
||||||
const FillProjectiveTransform<Device, T> functor;
|
(FillProjectiveTransform<Device, T>(interpolation_))(
|
||||||
functor(ctx->eigen_device<Device>(), &output, images, transform);
|
ctx->eigen_device<Device>(), &output, images, transform);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -28,6 +28,8 @@ namespace tensorflow {
|
|||||||
|
|
||||||
namespace generator {
|
namespace generator {
|
||||||
|
|
||||||
|
enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
|
||||||
|
|
||||||
using Eigen::array;
|
using Eigen::array;
|
||||||
using Eigen::DenseIndex;
|
using Eigen::DenseIndex;
|
||||||
|
|
||||||
@ -36,20 +38,19 @@ class ProjectiveGenerator {
|
|||||||
private:
|
private:
|
||||||
typename TTypes<T, 4>::ConstTensor input_;
|
typename TTypes<T, 4>::ConstTensor input_;
|
||||||
typename TTypes<float>::ConstMatrix transforms_;
|
typename TTypes<float>::ConstMatrix transforms_;
|
||||||
|
const Interpolation interpolation_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static const int kNumParameters = 8;
|
static const int kNumParameters = 8;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||||
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
|
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
|
||||||
typename TTypes<float>::ConstMatrix transforms)
|
typename TTypes<float>::ConstMatrix transforms,
|
||||||
: input_(input), transforms_(transforms) {}
|
const Interpolation interpolation)
|
||||||
|
: input_(input), transforms_(transforms), interpolation_(interpolation) {}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||||
operator()(const array<DenseIndex, 4>& coords) const {
|
operator()(const array<DenseIndex, 4>& coords) const {
|
||||||
array<DenseIndex, 4> input_coords;
|
|
||||||
input_coords[0] = coords[0];
|
|
||||||
|
|
||||||
const int64 output_y = coords[1];
|
const int64 output_y = coords[1];
|
||||||
const int64 output_x = coords[2];
|
const int64 output_x = coords[2];
|
||||||
const float* transform =
|
const float* transform =
|
||||||
@ -57,24 +58,73 @@ class ProjectiveGenerator {
|
|||||||
? transforms_.data()
|
? transforms_.data()
|
||||||
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
|
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
|
||||||
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
|
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
|
||||||
const int64 input_x = std::round(
|
const float input_x =
|
||||||
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
|
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
|
||||||
projection);
|
projection;
|
||||||
const int64 input_y = std::round(
|
const float input_y =
|
||||||
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
|
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
|
||||||
projection);
|
projection;
|
||||||
|
|
||||||
if (!(0 <= input_y && input_y < input_.dimension(1) && 0 <= input_x &&
|
|
||||||
input_x < input_.dimension(2))) {
|
|
||||||
// TODO(ringwalt): Add a fill value input.
|
// TODO(ringwalt): Add a fill value input.
|
||||||
return T(0);
|
static const T fill_value = T(0);
|
||||||
|
switch (interpolation_) {
|
||||||
|
case INTERPOLATION_NEAREST:
|
||||||
|
// Switch the order of x and y again for indexing into the image.
|
||||||
|
return nearest_interpolation(coords[0], input_y, input_x, coords[3],
|
||||||
|
fill_value);
|
||||||
|
case INTERPOLATION_BILINEAR:
|
||||||
|
return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
|
||||||
|
fill_value);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
input_coords[1] = input_y;
|
|
||||||
input_coords[2] = input_x;
|
|
||||||
|
|
||||||
input_coords[3] = coords[3];
|
private:
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||||
|
nearest_interpolation(const DenseIndex batch, const float y, const float x,
|
||||||
|
const DenseIndex channel, const T fill_value) const {
|
||||||
|
return read_with_fill_value(batch, DenseIndex(std::round(y)),
|
||||||
|
DenseIndex(std::round(x)), channel, fill_value);
|
||||||
|
}
|
||||||
|
|
||||||
return input_(input_coords);
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||||
|
bilinear_interpolation(const DenseIndex batch, const float y, const float x,
|
||||||
|
const DenseIndex channel, const T fill_value) const {
|
||||||
|
const float y_floor = std::floor(y);
|
||||||
|
const float x_floor = std::floor(x);
|
||||||
|
const float y_ceil = y_floor + 1;
|
||||||
|
const float x_ceil = x_floor + 1;
|
||||||
|
// f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
|
||||||
|
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
|
||||||
|
const float value_yfloor =
|
||||||
|
(x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
|
||||||
|
DenseIndex(x_floor), channel,
|
||||||
|
fill_value) +
|
||||||
|
(x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
|
||||||
|
DenseIndex(x_ceil), channel,
|
||||||
|
fill_value);
|
||||||
|
// f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
|
||||||
|
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
|
||||||
|
const float value_yceil =
|
||||||
|
(x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
|
||||||
|
DenseIndex(x_floor), channel,
|
||||||
|
fill_value) +
|
||||||
|
(x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
|
||||||
|
DenseIndex(x_ceil), channel,
|
||||||
|
fill_value);
|
||||||
|
// f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
|
||||||
|
// + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
|
||||||
|
return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
|
||||||
|
const DenseIndex batch, const DenseIndex y, const DenseIndex x,
|
||||||
|
const DenseIndex channel, const T fill_value) const {
|
||||||
|
// batch and channel must be correct, because they are passed unchanged from
|
||||||
|
// the input.
|
||||||
|
return (0 <= y && y < input_.dimension(1) && 0 <= x &&
|
||||||
|
x < input_.dimension(2))
|
||||||
|
? input_(array<DenseIndex, 4>{batch, y, x, channel})
|
||||||
|
: fill_value;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -85,6 +135,7 @@ class ProjectiveGenerator {
|
|||||||
// some Eigen device code.
|
// some Eigen device code.
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
||||||
|
using generator::Interpolation;
|
||||||
using generator::ProjectiveGenerator;
|
using generator::ProjectiveGenerator;
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
@ -92,15 +143,17 @@ struct FillProjectiveTransform {
|
|||||||
typedef typename TTypes<T, 4>::Tensor OutputType;
|
typedef typename TTypes<T, 4>::Tensor OutputType;
|
||||||
typedef typename TTypes<T, 4>::ConstTensor InputType;
|
typedef typename TTypes<T, 4>::ConstTensor InputType;
|
||||||
typedef typename TTypes<float, 2>::ConstTensor TransformsType;
|
typedef typename TTypes<float, 2>::ConstTensor TransformsType;
|
||||||
|
const Interpolation interpolation_;
|
||||||
|
|
||||||
FillProjectiveTransform() {}
|
FillProjectiveTransform(Interpolation interpolation)
|
||||||
|
: interpolation_(interpolation) {}
|
||||||
|
|
||||||
EIGEN_ALWAYS_INLINE
|
EIGEN_ALWAYS_INLINE
|
||||||
void operator()(const Device& device, OutputType* output,
|
void operator()(const Device& device, OutputType* output,
|
||||||
const InputType& images,
|
const InputType& images,
|
||||||
const TransformsType& transform) const {
|
const TransformsType& transform) const {
|
||||||
ProjectiveGenerator<Device, T> generator(images, transform);
|
output->device(device) = images.generate(
|
||||||
output->device(device) = images.generate(generator);
|
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -23,13 +23,13 @@ using shape_inference::InferenceContext;
|
|||||||
|
|
||||||
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
|
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
|
||||||
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
|
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
|
||||||
// TODO(ringwalt): Add an "interpolation" argument with "none", "bilinear", etc.
|
|
||||||
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
|
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
|
||||||
// implement "same" and "valid" modes in the Python function.
|
// implement "same" and "valid" modes in the Python function.
|
||||||
REGISTER_OP("ImageProjectiveTransform")
|
REGISTER_OP("ImageProjectiveTransform")
|
||||||
.Input("images: dtype")
|
.Input("images: dtype")
|
||||||
.Input("transforms: float32")
|
.Input("transforms: float32")
|
||||||
.Attr("dtype: {uint8, int32, int64, float32, float64}")
|
.Attr("dtype: {uint8, int32, int64, float32, float64}")
|
||||||
|
.Attr("interpolation: string")
|
||||||
.Output("transformed_images: dtype")
|
.Output("transformed_images: dtype")
|
||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
c->set_output(0, c->input(0));
|
c->set_output(0, c->input(0));
|
||||||
|
@ -111,6 +111,55 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
|
|||||||
[0, 1, 0, 1],
|
[0, 1, 0, 1],
|
||||||
[0, 1, 1, 1]])
|
[0, 1, 1, 1]])
|
||||||
|
|
||||||
|
def test_bilinear(self):
|
||||||
|
with self.test_session():
|
||||||
|
image = constant_op.constant(
|
||||||
|
[[0, 0, 0, 0, 0],
|
||||||
|
[0, 1, 1, 1, 0],
|
||||||
|
[0, 1, 0, 1, 0],
|
||||||
|
[0, 1, 1, 1, 0],
|
||||||
|
[0, 0, 0, 0, 0]],
|
||||||
|
dtypes.float32)
|
||||||
|
# The following result matches:
|
||||||
|
# >>> scipy.ndimage.rotate(image, 45, order=1, reshape=False)
|
||||||
|
# which uses spline interpolation of order 1, equivalent to bilinear
|
||||||
|
# interpolation.
|
||||||
|
self.assertAllClose(
|
||||||
|
image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
|
||||||
|
[[0.000, 0.000, 0.343, 0.000, 0.000],
|
||||||
|
[0.000, 0.586, 0.914, 0.586, 0.000],
|
||||||
|
[0.343, 0.914, 0.000, 0.914, 0.343],
|
||||||
|
[0.000, 0.586, 0.914, 0.586, 0.000],
|
||||||
|
[0.000, 0.000, 0.343, 0.000, 0.000]],
|
||||||
|
atol=0.001)
|
||||||
|
self.assertAllClose(
|
||||||
|
image_ops.rotate(image, np.pi / 4.0, interpolation="NEAREST").eval(),
|
||||||
|
[[0, 0, 1, 0, 0],
|
||||||
|
[0, 1, 1, 1, 0],
|
||||||
|
[1, 1, 0, 1, 1],
|
||||||
|
[0, 1, 1, 1, 0],
|
||||||
|
[0, 0, 1, 0, 0]])
|
||||||
|
|
||||||
|
def test_bilinear_uint8(self):
|
||||||
|
with self.test_session():
|
||||||
|
image = constant_op.constant(
|
||||||
|
np.asarray(
|
||||||
|
[[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 255, 255, 255, 0.0],
|
||||||
|
[0.0, 255, 0.0, 255, 0.0],
|
||||||
|
[0.0, 255, 255, 255, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0]],
|
||||||
|
np.uint8),
|
||||||
|
dtypes.uint8)
|
||||||
|
# == np.rint((expected image above) * 255)
|
||||||
|
self.assertAllEqual(
|
||||||
|
image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
|
||||||
|
[[0.0, 0.0, 87., 0.0, 0.0],
|
||||||
|
[0.0, 149, 233, 149, 0.0],
|
||||||
|
[87., 233, 0.0, 233, 87.],
|
||||||
|
[0.0, 149, 233, 149, 0.0],
|
||||||
|
[0.0, 0.0, 87., 0.0, 0.0]])
|
||||||
|
|
||||||
def _test_grad(self, shape_to_test):
|
def _test_grad(self, shape_to_test):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
test_image_shape = shape_to_test
|
test_image_shape = shape_to_test
|
||||||
|
@ -37,7 +37,7 @@ _IMAGE_DTYPES = set(
|
|||||||
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
|
||||||
|
|
||||||
|
|
||||||
def rotate(images, angles):
|
def rotate(images, angles, interpolation="NEAREST"):
|
||||||
"""Rotate image(s) by the passed angle(s) in radians.
|
"""Rotate image(s) by the passed angle(s) in radians.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -46,6 +46,7 @@ def rotate(images, angles):
|
|||||||
(num_rows, num_columns) (HW).
|
(num_rows, num_columns) (HW).
|
||||||
angles: A scalar angle to rotate all images by, or (if images has rank 4)
|
angles: A scalar angle to rotate all images by, or (if images has rank 4)
|
||||||
a vector of length num_images, with an angle for each image in the batch.
|
a vector of length num_images, with an angle for each image in the batch.
|
||||||
|
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Image(s) with the same type and shape as `images`, rotated by the given
|
Image(s) with the same type and shape as `images`, rotated by the given
|
||||||
@ -70,7 +71,8 @@ def rotate(images, angles):
|
|||||||
image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
|
image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
|
||||||
output = transform(
|
output = transform(
|
||||||
images,
|
images,
|
||||||
angles_to_projective_transforms(angles, image_width, image_height))
|
angles_to_projective_transforms(angles, image_height, image_width),
|
||||||
|
interpolation=interpolation)
|
||||||
if len(image_or_images.get_shape()) == 2:
|
if len(image_or_images.get_shape()) == 2:
|
||||||
return output[0, :, :, 0]
|
return output[0, :, :, 0]
|
||||||
elif len(image_or_images.get_shape()) == 3:
|
elif len(image_or_images.get_shape()) == 3:
|
||||||
@ -120,7 +122,7 @@ def angles_to_projective_transforms(angles, image_height, image_width):
|
|||||||
axis=1)
|
axis=1)
|
||||||
|
|
||||||
|
|
||||||
def transform(images, transforms):
|
def transform(images, transforms, interpolation="NEAREST"):
|
||||||
"""Applies the given transform(s) to the image(s).
|
"""Applies the given transform(s) to the image(s).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -134,6 +136,7 @@ def transform(images, transforms):
|
|||||||
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
|
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
|
||||||
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
|
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
|
||||||
the transform mapping input points to output points.
|
the transform mapping input points to output points.
|
||||||
|
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Image(s) with the same type and shape as `images`, with the given
|
Image(s) with the same type and shape as `images`, with the given
|
||||||
@ -163,8 +166,8 @@ def transform(images, transforms):
|
|||||||
transforms = transform_or_transforms
|
transforms = transform_or_transforms
|
||||||
else:
|
else:
|
||||||
raise TypeError("Transforms should have rank 1 or 2.")
|
raise TypeError("Transforms should have rank 1 or 2.")
|
||||||
# pylint: disable=protected-access
|
output = gen_image_ops.image_projective_transform(
|
||||||
output = gen_image_ops.image_projective_transform(images, transforms)
|
images, transforms, interpolation=interpolation.upper())
|
||||||
if len(image_or_images.get_shape()) == 2:
|
if len(image_or_images.get_shape()) == 2:
|
||||||
return output[0, :, :, 0]
|
return output[0, :, :, 0]
|
||||||
elif len(image_or_images.get_shape()) == 3:
|
elif len(image_or_images.get_shape()) == 3:
|
||||||
@ -220,6 +223,7 @@ def _image_projective_transform_grad(op, grad):
|
|||||||
"""Computes the gradient for ImageProjectiveTransform."""
|
"""Computes the gradient for ImageProjectiveTransform."""
|
||||||
images = op.inputs[0]
|
images = op.inputs[0]
|
||||||
transforms = op.inputs[1]
|
transforms = op.inputs[1]
|
||||||
|
interpolation = op.get_attr("interpolation")
|
||||||
|
|
||||||
image_or_images = ops.convert_to_tensor(images, name="images")
|
image_or_images = ops.convert_to_tensor(images, name="images")
|
||||||
transform_or_transforms = ops.convert_to_tensor(
|
transform_or_transforms = ops.convert_to_tensor(
|
||||||
@ -246,7 +250,8 @@ def _image_projective_transform_grad(op, grad):
|
|||||||
transforms = _flat_transforms_to_matrices(transforms=transforms)
|
transforms = _flat_transforms_to_matrices(transforms=transforms)
|
||||||
inverse = linalg_ops.matrix_inverse(transforms)
|
inverse = linalg_ops.matrix_inverse(transforms)
|
||||||
transforms = _transform_matrices_to_flat(inverse)
|
transforms = _transform_matrices_to_flat(inverse)
|
||||||
output = gen_image_ops.image_projective_transform(grad, transforms)
|
output = gen_image_ops.image_projective_transform(
|
||||||
|
grad, transforms, interpolation=interpolation)
|
||||||
if len(image_or_images.get_shape()) == 2:
|
if len(image_or_images.get_shape()) == 2:
|
||||||
return [output[0, :, :, 0], None]
|
return [output[0, :, :, 0], None]
|
||||||
elif len(image_or_images.get_shape()) == 3:
|
elif len(image_or_images.get_shape()) == 3:
|
||||||
|
Loading…
Reference in New Issue
Block a user