Add bilinear interpolation to tf.contrib.image.
Change: 155247916
This commit is contained in:
parent
fe16da297c
commit
37e3b71b49
@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>;
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
using functor::FillProjectiveTransform;
|
||||
using generator::INTERPOLATION_BILINEAR;
|
||||
using generator::INTERPOLATION_NEAREST;
|
||||
using generator::Interpolation;
|
||||
using generator::ProjectiveGenerator;
|
||||
|
||||
template <typename Device, typename T>
|
||||
class ImageProjectiveTransform : public OpKernel {
|
||||
private:
|
||||
Interpolation interpolation_;
|
||||
|
||||
public:
|
||||
explicit ImageProjectiveTransform(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {}
|
||||
explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
string interpolation_str;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
|
||||
if (interpolation_str == "NEAREST") {
|
||||
interpolation_ = INTERPOLATION_NEAREST;
|
||||
} else if (interpolation_str == "BILINEAR") {
|
||||
interpolation_ = INTERPOLATION_BILINEAR;
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid interpolation " << interpolation_str
|
||||
<< ". Supported types: NEAREST, BILINEAR";
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& images_t = ctx->input(0);
|
||||
@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel {
|
||||
Tensor* output_t;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
|
||||
auto output = output_t->tensor<T, 4>();
|
||||
const FillProjectiveTransform<Device, T> functor;
|
||||
functor(ctx->eigen_device<Device>(), &output, images, transform);
|
||||
(FillProjectiveTransform<Device, T>(interpolation_))(
|
||||
ctx->eigen_device<Device>(), &output, images, transform);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -28,6 +28,8 @@ namespace tensorflow {
|
||||
|
||||
namespace generator {
|
||||
|
||||
enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
|
||||
|
||||
using Eigen::array;
|
||||
using Eigen::DenseIndex;
|
||||
|
||||
@ -36,20 +38,19 @@ class ProjectiveGenerator {
|
||||
private:
|
||||
typename TTypes<T, 4>::ConstTensor input_;
|
||||
typename TTypes<float>::ConstMatrix transforms_;
|
||||
const Interpolation interpolation_;
|
||||
|
||||
public:
|
||||
static const int kNumParameters = 8;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<float>::ConstMatrix transforms)
|
||||
: input_(input), transforms_(transforms) {}
|
||||
typename TTypes<float>::ConstMatrix transforms,
|
||||
const Interpolation interpolation)
|
||||
: input_(input), transforms_(transforms), interpolation_(interpolation) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||
operator()(const array<DenseIndex, 4>& coords) const {
|
||||
array<DenseIndex, 4> input_coords;
|
||||
input_coords[0] = coords[0];
|
||||
|
||||
const int64 output_y = coords[1];
|
||||
const int64 output_x = coords[2];
|
||||
const float* transform =
|
||||
@ -57,24 +58,73 @@ class ProjectiveGenerator {
|
||||
? transforms_.data()
|
||||
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
|
||||
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
|
||||
const int64 input_x = std::round(
|
||||
const float input_x =
|
||||
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
|
||||
projection);
|
||||
const int64 input_y = std::round(
|
||||
projection;
|
||||
const float input_y =
|
||||
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
|
||||
projection);
|
||||
projection;
|
||||
|
||||
if (!(0 <= input_y && input_y < input_.dimension(1) && 0 <= input_x &&
|
||||
input_x < input_.dimension(2))) {
|
||||
// TODO(ringwalt): Add a fill value input.
|
||||
return T(0);
|
||||
// TODO(ringwalt): Add a fill value input.
|
||||
static const T fill_value = T(0);
|
||||
switch (interpolation_) {
|
||||
case INTERPOLATION_NEAREST:
|
||||
// Switch the order of x and y again for indexing into the image.
|
||||
return nearest_interpolation(coords[0], input_y, input_x, coords[3],
|
||||
fill_value);
|
||||
case INTERPOLATION_BILINEAR:
|
||||
return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
|
||||
fill_value);
|
||||
}
|
||||
input_coords[1] = input_y;
|
||||
input_coords[2] = input_x;
|
||||
}
|
||||
|
||||
input_coords[3] = coords[3];
|
||||
private:
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||
nearest_interpolation(const DenseIndex batch, const float y, const float x,
|
||||
const DenseIndex channel, const T fill_value) const {
|
||||
return read_with_fill_value(batch, DenseIndex(std::round(y)),
|
||||
DenseIndex(std::round(x)), channel, fill_value);
|
||||
}
|
||||
|
||||
return input_(input_coords);
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||
bilinear_interpolation(const DenseIndex batch, const float y, const float x,
|
||||
const DenseIndex channel, const T fill_value) const {
|
||||
const float y_floor = std::floor(y);
|
||||
const float x_floor = std::floor(x);
|
||||
const float y_ceil = y_floor + 1;
|
||||
const float x_ceil = x_floor + 1;
|
||||
// f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
|
||||
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
|
||||
const float value_yfloor =
|
||||
(x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
|
||||
DenseIndex(x_floor), channel,
|
||||
fill_value) +
|
||||
(x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
|
||||
DenseIndex(x_ceil), channel,
|
||||
fill_value);
|
||||
// f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
|
||||
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
|
||||
const float value_yceil =
|
||||
(x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
|
||||
DenseIndex(x_floor), channel,
|
||||
fill_value) +
|
||||
(x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
|
||||
DenseIndex(x_ceil), channel,
|
||||
fill_value);
|
||||
// f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
|
||||
// + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
|
||||
return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
|
||||
const DenseIndex batch, const DenseIndex y, const DenseIndex x,
|
||||
const DenseIndex channel, const T fill_value) const {
|
||||
// batch and channel must be correct, because they are passed unchanged from
|
||||
// the input.
|
||||
return (0 <= y && y < input_.dimension(1) && 0 <= x &&
|
||||
x < input_.dimension(2))
|
||||
? input_(array<DenseIndex, 4>{batch, y, x, channel})
|
||||
: fill_value;
|
||||
}
|
||||
};
|
||||
|
||||
@ -85,6 +135,7 @@ class ProjectiveGenerator {
|
||||
// some Eigen device code.
|
||||
namespace functor {
|
||||
|
||||
using generator::Interpolation;
|
||||
using generator::ProjectiveGenerator;
|
||||
|
||||
template <typename Device, typename T>
|
||||
@ -92,15 +143,17 @@ struct FillProjectiveTransform {
|
||||
typedef typename TTypes<T, 4>::Tensor OutputType;
|
||||
typedef typename TTypes<T, 4>::ConstTensor InputType;
|
||||
typedef typename TTypes<float, 2>::ConstTensor TransformsType;
|
||||
const Interpolation interpolation_;
|
||||
|
||||
FillProjectiveTransform() {}
|
||||
FillProjectiveTransform(Interpolation interpolation)
|
||||
: interpolation_(interpolation) {}
|
||||
|
||||
EIGEN_ALWAYS_INLINE
|
||||
void operator()(const Device& device, OutputType* output,
|
||||
const InputType& images,
|
||||
const TransformsType& transform) const {
|
||||
ProjectiveGenerator<Device, T> generator(images, transform);
|
||||
output->device(device) = images.generate(generator);
|
||||
output->device(device) = images.generate(
|
||||
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -23,13 +23,13 @@ using shape_inference::InferenceContext;
|
||||
|
||||
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
|
||||
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
|
||||
// TODO(ringwalt): Add an "interpolation" argument with "none", "bilinear", etc.
|
||||
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
|
||||
// implement "same" and "valid" modes in the Python function.
|
||||
REGISTER_OP("ImageProjectiveTransform")
|
||||
.Input("images: dtype")
|
||||
.Input("transforms: float32")
|
||||
.Attr("dtype: {uint8, int32, int64, float32, float64}")
|
||||
.Attr("interpolation: string")
|
||||
.Output("transformed_images: dtype")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->input(0));
|
||||
|
@ -111,6 +111,55 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
|
||||
[0, 1, 0, 1],
|
||||
[0, 1, 1, 1]])
|
||||
|
||||
def test_bilinear(self):
|
||||
with self.test_session():
|
||||
image = constant_op.constant(
|
||||
[[0, 0, 0, 0, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 1, 0, 1, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 0, 0, 0, 0]],
|
||||
dtypes.float32)
|
||||
# The following result matches:
|
||||
# >>> scipy.ndimage.rotate(image, 45, order=1, reshape=False)
|
||||
# which uses spline interpolation of order 1, equivalent to bilinear
|
||||
# interpolation.
|
||||
self.assertAllClose(
|
||||
image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
|
||||
[[0.000, 0.000, 0.343, 0.000, 0.000],
|
||||
[0.000, 0.586, 0.914, 0.586, 0.000],
|
||||
[0.343, 0.914, 0.000, 0.914, 0.343],
|
||||
[0.000, 0.586, 0.914, 0.586, 0.000],
|
||||
[0.000, 0.000, 0.343, 0.000, 0.000]],
|
||||
atol=0.001)
|
||||
self.assertAllClose(
|
||||
image_ops.rotate(image, np.pi / 4.0, interpolation="NEAREST").eval(),
|
||||
[[0, 0, 1, 0, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
[1, 1, 0, 1, 1],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 0, 1, 0, 0]])
|
||||
|
||||
def test_bilinear_uint8(self):
|
||||
with self.test_session():
|
||||
image = constant_op.constant(
|
||||
np.asarray(
|
||||
[[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 255, 255, 255, 0.0],
|
||||
[0.0, 255, 0.0, 255, 0.0],
|
||||
[0.0, 255, 255, 255, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0]],
|
||||
np.uint8),
|
||||
dtypes.uint8)
|
||||
# == np.rint((expected image above) * 255)
|
||||
self.assertAllEqual(
|
||||
image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
|
||||
[[0.0, 0.0, 87., 0.0, 0.0],
|
||||
[0.0, 149, 233, 149, 0.0],
|
||||
[87., 233, 0.0, 233, 87.],
|
||||
[0.0, 149, 233, 149, 0.0],
|
||||
[0.0, 0.0, 87., 0.0, 0.0]])
|
||||
|
||||
def _test_grad(self, shape_to_test):
|
||||
with self.test_session():
|
||||
test_image_shape = shape_to_test
|
||||
|
@ -37,7 +37,7 @@ _IMAGE_DTYPES = set(
|
||||
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
|
||||
|
||||
|
||||
def rotate(images, angles):
|
||||
def rotate(images, angles, interpolation="NEAREST"):
|
||||
"""Rotate image(s) by the passed angle(s) in radians.
|
||||
|
||||
Args:
|
||||
@ -46,6 +46,7 @@ def rotate(images, angles):
|
||||
(num_rows, num_columns) (HW).
|
||||
angles: A scalar angle to rotate all images by, or (if images has rank 4)
|
||||
a vector of length num_images, with an angle for each image in the batch.
|
||||
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
|
||||
|
||||
Returns:
|
||||
Image(s) with the same type and shape as `images`, rotated by the given
|
||||
@ -70,7 +71,8 @@ def rotate(images, angles):
|
||||
image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
|
||||
output = transform(
|
||||
images,
|
||||
angles_to_projective_transforms(angles, image_width, image_height))
|
||||
angles_to_projective_transforms(angles, image_height, image_width),
|
||||
interpolation=interpolation)
|
||||
if len(image_or_images.get_shape()) == 2:
|
||||
return output[0, :, :, 0]
|
||||
elif len(image_or_images.get_shape()) == 3:
|
||||
@ -120,7 +122,7 @@ def angles_to_projective_transforms(angles, image_height, image_width):
|
||||
axis=1)
|
||||
|
||||
|
||||
def transform(images, transforms):
|
||||
def transform(images, transforms, interpolation="NEAREST"):
|
||||
"""Applies the given transform(s) to the image(s).
|
||||
|
||||
Args:
|
||||
@ -134,6 +136,7 @@ def transform(images, transforms):
|
||||
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
|
||||
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
|
||||
the transform mapping input points to output points.
|
||||
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
|
||||
|
||||
Returns:
|
||||
Image(s) with the same type and shape as `images`, with the given
|
||||
@ -163,8 +166,8 @@ def transform(images, transforms):
|
||||
transforms = transform_or_transforms
|
||||
else:
|
||||
raise TypeError("Transforms should have rank 1 or 2.")
|
||||
# pylint: disable=protected-access
|
||||
output = gen_image_ops.image_projective_transform(images, transforms)
|
||||
output = gen_image_ops.image_projective_transform(
|
||||
images, transforms, interpolation=interpolation.upper())
|
||||
if len(image_or_images.get_shape()) == 2:
|
||||
return output[0, :, :, 0]
|
||||
elif len(image_or_images.get_shape()) == 3:
|
||||
@ -220,6 +223,7 @@ def _image_projective_transform_grad(op, grad):
|
||||
"""Computes the gradient for ImageProjectiveTransform."""
|
||||
images = op.inputs[0]
|
||||
transforms = op.inputs[1]
|
||||
interpolation = op.get_attr("interpolation")
|
||||
|
||||
image_or_images = ops.convert_to_tensor(images, name="images")
|
||||
transform_or_transforms = ops.convert_to_tensor(
|
||||
@ -246,7 +250,8 @@ def _image_projective_transform_grad(op, grad):
|
||||
transforms = _flat_transforms_to_matrices(transforms=transforms)
|
||||
inverse = linalg_ops.matrix_inverse(transforms)
|
||||
transforms = _transform_matrices_to_flat(inverse)
|
||||
output = gen_image_ops.image_projective_transform(grad, transforms)
|
||||
output = gen_image_ops.image_projective_transform(
|
||||
grad, transforms, interpolation=interpolation)
|
||||
if len(image_or_images.get_shape()) == 2:
|
||||
return [output[0, :, :, 0], None]
|
||||
elif len(image_or_images.get_shape()) == 3:
|
||||
|
Loading…
Reference in New Issue
Block a user