Add bilinear interpolation to tf.contrib.image.

Change: 155247916
This commit is contained in:
Dan Ringwalt 2017-05-05 14:28:59 -08:00 committed by TensorFlower Gardener
parent fe16da297c
commit 37e3b71b49
5 changed files with 154 additions and 31 deletions

View File

@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>;
typedef Eigen::ThreadPoolDevice CPUDevice;
using functor::FillProjectiveTransform;
using generator::INTERPOLATION_BILINEAR;
using generator::INTERPOLATION_NEAREST;
using generator::Interpolation;
using generator::ProjectiveGenerator;
template <typename Device, typename T>
class ImageProjectiveTransform : public OpKernel {
private:
Interpolation interpolation_;
public:
explicit ImageProjectiveTransform(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
string interpolation_str;
OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
if (interpolation_str == "NEAREST") {
interpolation_ = INTERPOLATION_NEAREST;
} else if (interpolation_str == "BILINEAR") {
interpolation_ = INTERPOLATION_BILINEAR;
} else {
LOG(FATAL) << "Invalid interpolation " << interpolation_str
<< ". Supported types: NEAREST, BILINEAR";
}
}
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel {
Tensor* output_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
auto output = output_t->tensor<T, 4>();
const FillProjectiveTransform<Device, T> functor;
functor(ctx->eigen_device<Device>(), &output, images, transform);
(FillProjectiveTransform<Device, T>(interpolation_))(
ctx->eigen_device<Device>(), &output, images, transform);
}
};

View File

@ -28,6 +28,8 @@ namespace tensorflow {
namespace generator {
enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
using Eigen::array;
using Eigen::DenseIndex;
@ -36,20 +38,19 @@ class ProjectiveGenerator {
private:
typename TTypes<T, 4>::ConstTensor input_;
typename TTypes<float>::ConstMatrix transforms_;
const Interpolation interpolation_;
public:
static const int kNumParameters = 8;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
typename TTypes<float>::ConstMatrix transforms)
: input_(input), transforms_(transforms) {}
typename TTypes<float>::ConstMatrix transforms,
const Interpolation interpolation)
: input_(input), transforms_(transforms), interpolation_(interpolation) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()(const array<DenseIndex, 4>& coords) const {
array<DenseIndex, 4> input_coords;
input_coords[0] = coords[0];
const int64 output_y = coords[1];
const int64 output_x = coords[2];
const float* transform =
@ -57,24 +58,73 @@ class ProjectiveGenerator {
? transforms_.data()
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
const int64 input_x = std::round(
const float input_x =
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
projection);
const int64 input_y = std::round(
projection;
const float input_y =
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
projection);
projection;
if (!(0 <= input_y && input_y < input_.dimension(1) && 0 <= input_x &&
input_x < input_.dimension(2))) {
// TODO(ringwalt): Add a fill value input.
return T(0);
// TODO(ringwalt): Add a fill value input.
static const T fill_value = T(0);
switch (interpolation_) {
case INTERPOLATION_NEAREST:
// Switch the order of x and y again for indexing into the image.
return nearest_interpolation(coords[0], input_y, input_x, coords[3],
fill_value);
case INTERPOLATION_BILINEAR:
return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
fill_value);
}
input_coords[1] = input_y;
input_coords[2] = input_x;
}
input_coords[3] = coords[3];
private:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
nearest_interpolation(const DenseIndex batch, const float y, const float x,
const DenseIndex channel, const T fill_value) const {
return read_with_fill_value(batch, DenseIndex(std::round(y)),
DenseIndex(std::round(x)), channel, fill_value);
}
return input_(input_coords);
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
bilinear_interpolation(const DenseIndex batch, const float y, const float x,
const DenseIndex channel, const T fill_value) const {
const float y_floor = std::floor(y);
const float x_floor = std::floor(x);
const float y_ceil = y_floor + 1;
const float x_ceil = x_floor + 1;
// f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
const float value_yfloor =
(x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
DenseIndex(x_floor), channel,
fill_value) +
(x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
DenseIndex(x_ceil), channel,
fill_value);
// f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
const float value_yceil =
(x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
DenseIndex(x_floor), channel,
fill_value) +
(x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
DenseIndex(x_ceil), channel,
fill_value);
// f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
// + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
const DenseIndex batch, const DenseIndex y, const DenseIndex x,
const DenseIndex channel, const T fill_value) const {
// batch and channel must be correct, because they are passed unchanged from
// the input.
return (0 <= y && y < input_.dimension(1) && 0 <= x &&
x < input_.dimension(2))
? input_(array<DenseIndex, 4>{batch, y, x, channel})
: fill_value;
}
};
@ -85,6 +135,7 @@ class ProjectiveGenerator {
// some Eigen device code.
namespace functor {
using generator::Interpolation;
using generator::ProjectiveGenerator;
template <typename Device, typename T>
@ -92,15 +143,17 @@ struct FillProjectiveTransform {
typedef typename TTypes<T, 4>::Tensor OutputType;
typedef typename TTypes<T, 4>::ConstTensor InputType;
typedef typename TTypes<float, 2>::ConstTensor TransformsType;
const Interpolation interpolation_;
FillProjectiveTransform() {}
FillProjectiveTransform(Interpolation interpolation)
: interpolation_(interpolation) {}
EIGEN_ALWAYS_INLINE
void operator()(const Device& device, OutputType* output,
const InputType& images,
const TransformsType& transform) const {
ProjectiveGenerator<Device, T> generator(images, transform);
output->device(device) = images.generate(generator);
output->device(device) = images.generate(
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
}
};

View File

@ -23,13 +23,13 @@ using shape_inference::InferenceContext;
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
// TODO(ringwalt): Add an "interpolation" argument with "none", "bilinear", etc.
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
// implement "same" and "valid" modes in the Python function.
REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype")
.Input("transforms: float32")
.Attr("dtype: {uint8, int32, int64, float32, float64}")
.Attr("interpolation: string")
.Output("transformed_images: dtype")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->input(0));

View File

@ -111,6 +111,55 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
[0, 1, 0, 1],
[0, 1, 1, 1]])
def test_bilinear(self):
with self.test_session():
image = constant_op.constant(
[[0, 0, 0, 0, 0],
[0, 1, 1, 1, 0],
[0, 1, 0, 1, 0],
[0, 1, 1, 1, 0],
[0, 0, 0, 0, 0]],
dtypes.float32)
# The following result matches:
# >>> scipy.ndimage.rotate(image, 45, order=1, reshape=False)
# which uses spline interpolation of order 1, equivalent to bilinear
# interpolation.
self.assertAllClose(
image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
[[0.000, 0.000, 0.343, 0.000, 0.000],
[0.000, 0.586, 0.914, 0.586, 0.000],
[0.343, 0.914, 0.000, 0.914, 0.343],
[0.000, 0.586, 0.914, 0.586, 0.000],
[0.000, 0.000, 0.343, 0.000, 0.000]],
atol=0.001)
self.assertAllClose(
image_ops.rotate(image, np.pi / 4.0, interpolation="NEAREST").eval(),
[[0, 0, 1, 0, 0],
[0, 1, 1, 1, 0],
[1, 1, 0, 1, 1],
[0, 1, 1, 1, 0],
[0, 0, 1, 0, 0]])
def test_bilinear_uint8(self):
with self.test_session():
image = constant_op.constant(
np.asarray(
[[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 255, 255, 255, 0.0],
[0.0, 255, 0.0, 255, 0.0],
[0.0, 255, 255, 255, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]],
np.uint8),
dtypes.uint8)
# == np.rint((expected image above) * 255)
self.assertAllEqual(
image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
[[0.0, 0.0, 87., 0.0, 0.0],
[0.0, 149, 233, 149, 0.0],
[87., 233, 0.0, 233, 87.],
[0.0, 149, 233, 149, 0.0],
[0.0, 0.0, 87., 0.0, 0.0]])
def _test_grad(self, shape_to_test):
with self.test_session():
test_image_shape = shape_to_test

View File

@ -37,7 +37,7 @@ _IMAGE_DTYPES = set(
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
def rotate(images, angles):
def rotate(images, angles, interpolation="NEAREST"):
"""Rotate image(s) by the passed angle(s) in radians.
Args:
@ -46,6 +46,7 @@ def rotate(images, angles):
(num_rows, num_columns) (HW).
angles: A scalar angle to rotate all images by, or (if images has rank 4)
a vector of length num_images, with an angle for each image in the batch.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
Returns:
Image(s) with the same type and shape as `images`, rotated by the given
@ -70,7 +71,8 @@ def rotate(images, angles):
image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
output = transform(
images,
angles_to_projective_transforms(angles, image_width, image_height))
angles_to_projective_transforms(angles, image_height, image_width),
interpolation=interpolation)
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@ -120,7 +122,7 @@ def angles_to_projective_transforms(angles, image_height, image_width):
axis=1)
def transform(images, transforms):
def transform(images, transforms, interpolation="NEAREST"):
"""Applies the given transform(s) to the image(s).
Args:
@ -134,6 +136,7 @@ def transform(images, transforms):
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
the transform mapping input points to output points.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
Returns:
Image(s) with the same type and shape as `images`, with the given
@ -163,8 +166,8 @@ def transform(images, transforms):
transforms = transform_or_transforms
else:
raise TypeError("Transforms should have rank 1 or 2.")
# pylint: disable=protected-access
output = gen_image_ops.image_projective_transform(images, transforms)
output = gen_image_ops.image_projective_transform(
images, transforms, interpolation=interpolation.upper())
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@ -220,6 +223,7 @@ def _image_projective_transform_grad(op, grad):
"""Computes the gradient for ImageProjectiveTransform."""
images = op.inputs[0]
transforms = op.inputs[1]
interpolation = op.get_attr("interpolation")
image_or_images = ops.convert_to_tensor(images, name="images")
transform_or_transforms = ops.convert_to_tensor(
@ -246,7 +250,8 @@ def _image_projective_transform_grad(op, grad):
transforms = _flat_transforms_to_matrices(transforms=transforms)
inverse = linalg_ops.matrix_inverse(transforms)
transforms = _transform_matrices_to_flat(inverse)
output = gen_image_ops.image_projective_transform(grad, transforms)
output = gen_image_ops.image_projective_transform(
grad, transforms, interpolation=interpolation)
if len(image_or_images.get_shape()) == 2:
return [output[0, :, :, 0], None]
elif len(image_or_images.get_shape()) == 3: