Export fill mode for image_projective_transform

PiperOrigin-RevId: 300652382
Change-Id: I762197f7dfbd545445db7b3b330d463f5f66d856
This commit is contained in:
Zhenyu Tan 2020-03-12 16:51:31 -07:00 committed by TensorFlower Gardener
parent 3614f46e1f
commit 287f14529c
15 changed files with 516 additions and 92 deletions

View File

@ -38,6 +38,12 @@ END
name: "interpolation" name: "interpolation"
description: <<END description: <<END
Interpolation method, "NEAREST" or "BILINEAR". Interpolation method, "NEAREST" or "BILINEAR".
END
}
attr {
name: "fill_mode"
description: <<END
Fill mode, "REFLECT", "WRAP", or "CONSTANT".
END END
} }
summary: "Applies the given transform to each of the images." summary: "Applies the given transform to each of the images."

View File

@ -46,27 +46,39 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
using functor::FillProjectiveTransform; using functor::FillProjectiveTransform;
using generator::Interpolation; using generator::Interpolation;
using generator::INTERPOLATION_BILINEAR; using generator::Mode;
using generator::INTERPOLATION_NEAREST;
using generator::ProjectiveGenerator;
template <typename Device, typename T> template <typename Device, typename T>
class ImageProjectiveTransform : public OpKernel { class ImageProjectiveTransformV2 : public OpKernel {
private: private:
Interpolation interpolation_; Interpolation interpolation_;
Mode fill_mode_;
public: public:
explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) { explicit ImageProjectiveTransformV2(OpKernelConstruction* ctx)
: OpKernel(ctx) {
string interpolation_str; string interpolation_str;
OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str)); OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
if (interpolation_str == "NEAREST") { if (interpolation_str == "NEAREST") {
interpolation_ = INTERPOLATION_NEAREST; interpolation_ = Interpolation::NEAREST;
} else if (interpolation_str == "BILINEAR") { } else if (interpolation_str == "BILINEAR") {
interpolation_ = INTERPOLATION_BILINEAR; interpolation_ = Interpolation::BILINEAR;
} else { } else {
LOG(ERROR) << "Invalid interpolation " << interpolation_str LOG(ERROR) << "Invalid interpolation " << interpolation_str
<< ". Supported types: NEAREST, BILINEAR"; << ". Supported types: NEAREST, BILINEAR";
} }
string mode_str;
OP_REQUIRES_OK(ctx, ctx->GetAttr("fill_mode", &mode_str));
if (mode_str == "REFLECT") {
fill_mode_ = Mode::REFLECT;
} else if (mode_str == "WRAP") {
fill_mode_ = Mode::WRAP;
} else if (mode_str == "CONSTANT") {
fill_mode_ = Mode::CONSTANT;
} else {
LOG(ERROR) << "Invalid mode " << mode_str
<< ". Supported types: REFLECT, WRAP, CONSTANT";
}
} }
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
@ -78,8 +90,7 @@ class ImageProjectiveTransform : public OpKernel {
(TensorShapeUtils::IsMatrix(transform_t.shape()) && (TensorShapeUtils::IsMatrix(transform_t.shape()) &&
(transform_t.dim_size(0) == images_t.dim_size(0) || (transform_t.dim_size(0) == images_t.dim_size(0) ||
transform_t.dim_size(0) == 1) && transform_t.dim_size(0) == 1) &&
transform_t.dim_size(1) == transform_t.dim_size(1) == 8),
ProjectiveGenerator<Device, T>::kNumParameters),
errors::InvalidArgument( errors::InvalidArgument(
"Input transform should be num_images x 8 or 1 x 8")); "Input transform should be num_images x 8 or 1 x 8"));
@ -116,15 +127,15 @@ class ImageProjectiveTransform : public OpKernel {
auto transform = transform_t.matrix<float>(); auto transform = transform_t.matrix<float>();
(FillProjectiveTransform<Device, T>(interpolation_))( (FillProjectiveTransform<Device, T>(interpolation_))(
ctx->eigen_device<Device>(), &output, images, transform); ctx->eigen_device<Device>(), &output, images, transform, fill_mode_);
} }
}; };
#define REGISTER(TYPE) \ #define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
.Device(DEVICE_CPU) \ .Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("dtype"), \ .TypeConstraint<TYPE>("dtype"), \
ImageProjectiveTransform<CPUDevice, TYPE>) ImageProjectiveTransformV2<CPUDevice, TYPE>)
TF_CALL_uint8(REGISTER); TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER); TF_CALL_int32(REGISTER);
@ -138,33 +149,48 @@ TF_CALL_double(REGISTER);
#if GOOGLE_CUDA #if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice; typedef Eigen::GpuDevice GPUDevice;
typedef generator::Mode Mode;
namespace functor { namespace functor {
// NOTE(ringwalt): We get an undefined symbol error if we don't explicitly // NOTE(ringwalt): We get an undefined symbol error if we don't explicitly
// instantiate the operator() in GCC'd code. // instantiate the operator() in GCC'd code.
#define DECLARE_FUNCTOR(TYPE) \ #define DECLARE_PROJECT_FUNCTOR(TYPE) \
template <> \ template <> \
void FillProjectiveTransform<GPUDevice, TYPE>::operator()( \ void FillProjectiveTransform<GPUDevice, TYPE>::operator()( \
const GPUDevice& device, OutputType* output, const InputType& images, \ const GPUDevice& device, OutputType* output, const InputType& images, \
const TransformsType& transform) const; \ const TransformsType& transform, const Mode fill_mode) const; \
extern template struct FillProjectiveTransform<GPUDevice, TYPE> extern template struct FillProjectiveTransform<GPUDevice, TYPE>
TF_CALL_uint8(DECLARE_FUNCTOR); TF_CALL_uint8(DECLARE_PROJECT_FUNCTOR);
TF_CALL_int32(DECLARE_FUNCTOR); TF_CALL_int32(DECLARE_PROJECT_FUNCTOR);
TF_CALL_int64(DECLARE_FUNCTOR); TF_CALL_int64(DECLARE_PROJECT_FUNCTOR);
TF_CALL_half(DECLARE_FUNCTOR); TF_CALL_half(DECLARE_PROJECT_FUNCTOR);
TF_CALL_float(DECLARE_FUNCTOR); TF_CALL_float(DECLARE_PROJECT_FUNCTOR);
TF_CALL_double(DECLARE_FUNCTOR); TF_CALL_double(DECLARE_PROJECT_FUNCTOR);
} // end namespace functor } // end namespace functor
#define REGISTER(TYPE) \ namespace generator {
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
.Device(DEVICE_GPU) \ #define DECLARE_MAP_FUNCTOR(Mode) \
.TypeConstraint<TYPE>("dtype") \ template <> \
.HostMemory("output_shape"), \ float MapCoordinate<GPUDevice, Mode>::operator()(const float out_coord, \
ImageProjectiveTransform<GPUDevice, TYPE>) const DenseIndex len); \
extern template struct MapCoordinate<GPUDevice, Mode>
DECLARE_MAP_FUNCTOR(Mode::REFLECT);
DECLARE_MAP_FUNCTOR(Mode::WRAP);
DECLARE_MAP_FUNCTOR(Mode::CONSTANT);
} // end namespace generator
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
.Device(DEVICE_GPU) \
.TypeConstraint<TYPE>("dtype") \
.HostMemory("output_shape"), \
ImageProjectiveTransformV2<GPUDevice, TYPE>)
TF_CALL_uint8(REGISTER); TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER); TF_CALL_int32(REGISTER);

View File

@ -29,12 +29,67 @@ namespace tensorflow {
namespace generator { namespace generator {
enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR }; enum Interpolation { NEAREST, BILINEAR };
enum Mode { REFLECT, WRAP, CONSTANT };
using Eigen::array; using Eigen::array;
using Eigen::DenseIndex; using Eigen::DenseIndex;
template <typename Device, typename T> template <typename Device, Mode M>
struct MapCoordinate {
float operator()(const float out_coord, const DenseIndex len);
};
template <typename Device>
struct MapCoordinate<Device, Mode::REFLECT> {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord,
const DenseIndex len) {
float in_coord = out_coord;
// Reflect [abcd] to [dcba|abcd|dcba], periodically from [0, 2 * len)
// over [abcddcba]
const DenseIndex boundary = 2 * len;
// Shift coordinate to (-boundary, boundary)
in_coord -= boundary * static_cast<DenseIndex>(in_coord / boundary);
// Convert negative coordinates from [-boundary, 0) to [0, boundary)
if (in_coord < 0) {
in_coord += boundary;
}
// Coordinate in_coord between [len, boundary) should reverse reflect
// to coordinate to (bounary - 1 - in_coord) between [0, len)
if (in_coord > len - 1) {
in_coord = boundary - 1 - in_coord;
}
return in_coord;
}
};
template <typename Device>
struct MapCoordinate<Device, Mode::WRAP> {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord,
const DenseIndex len) {
float in_coord = out_coord;
// Wrap [abcd] to [abcd|abcd|abcd], periodically from [0, len)
// over [abcd]
const DenseIndex boundary = len;
// Shift coordinate to (-boundary, boundary)
in_coord -= boundary * static_cast<DenseIndex>(in_coord / boundary);
// Shift negative coordinate from [-boundary, 0) to [0, boundary)
if (in_coord < 0) {
in_coord += boundary;
}
return in_coord;
}
};
template <typename Device>
struct MapCoordinate<Device, Mode::CONSTANT> {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord,
const DenseIndex len) {
return out_coord;
}
};
template <typename Device, typename T, Mode M>
class ProjectiveGenerator { class ProjectiveGenerator {
private: private:
typename TTypes<T, 4>::ConstTensor input_; typename TTypes<T, 4>::ConstTensor input_;
@ -42,8 +97,6 @@ class ProjectiveGenerator {
const Interpolation interpolation_; const Interpolation interpolation_;
public: public:
static const int kNumParameters = 8;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input, ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
typename TTypes<float>::ConstMatrix transforms, typename TTypes<float>::ConstMatrix transforms,
@ -52,6 +105,7 @@ class ProjectiveGenerator {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()(const array<DenseIndex, 4>& coords) const { operator()(const array<DenseIndex, 4>& coords) const {
const T fill_value = T(0);
const int64 output_y = coords[1]; const int64 output_y = coords[1];
const int64 output_x = coords[2]; const int64 output_x = coords[2];
const float* transform = const float* transform =
@ -62,7 +116,7 @@ class ProjectiveGenerator {
if (projection == 0) { if (projection == 0) {
// Return the fill value (0) for infinite coordinates, // Return the fill value (0) for infinite coordinates,
// which are outside the input image // which are outside the input image
return T(0); return fill_value;
} }
const float input_x = const float input_x =
(transform[0] * output_x + transform[1] * output_y + transform[2]) / (transform[0] * output_x + transform[1] * output_y + transform[2]) /
@ -71,22 +125,24 @@ class ProjectiveGenerator {
(transform[3] * output_x + transform[4] * output_y + transform[5]) / (transform[3] * output_x + transform[4] * output_y + transform[5]) /
projection; projection;
const T fill_value = T(0); // Map out-of-boundary input coordinates to in-boundary based on fill_mode.
auto map_functor = MapCoordinate<Device, M>();
const float x = map_functor(input_x, input_.dimension(2));
const float y = map_functor(input_y, input_.dimension(1));
const DenseIndex batch = coords[0];
const DenseIndex channels = coords[3];
switch (interpolation_) { switch (interpolation_) {
case INTERPOLATION_NEAREST: case NEAREST:
// Switch the order of x and y again for indexing into the image. return nearest_interpolation(batch, y, x, channels, fill_value);
return nearest_interpolation(coords[0], input_y, input_x, coords[3], case BILINEAR:
fill_value); return bilinear_interpolation(batch, y, x, channels, fill_value);
case INTERPOLATION_BILINEAR:
return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
fill_value);
} }
// Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST // Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST
// or INTERPOLATION_BILINEAR. // or INTERPOLATION_BILINEAR.
return T(0); return fill_value;
} }
private:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
nearest_interpolation(const DenseIndex batch, const float y, const float x, nearest_interpolation(const DenseIndex batch, const float y, const float x,
const DenseIndex channel, const T fill_value) const { const DenseIndex channel, const T fill_value) const {
@ -138,12 +194,10 @@ class ProjectiveGenerator {
} // end namespace generator } // end namespace generator
// NOTE(ringwalt): We MUST wrap the generate() call in a functor and explicitly
// instantiate the functor in image_ops_gpu.cu.cc. Otherwise, we will be missing
// some Eigen device code.
namespace functor { namespace functor {
using generator::Interpolation; using generator::Interpolation;
using generator::Mode;
using generator::ProjectiveGenerator; using generator::ProjectiveGenerator;
template <typename Device, typename T> template <typename Device, typename T>
@ -151,17 +205,32 @@ struct FillProjectiveTransform {
typedef typename TTypes<T, 4>::Tensor OutputType; typedef typename TTypes<T, 4>::Tensor OutputType;
typedef typename TTypes<T, 4>::ConstTensor InputType; typedef typename TTypes<T, 4>::ConstTensor InputType;
typedef typename TTypes<float, 2>::ConstTensor TransformsType; typedef typename TTypes<float, 2>::ConstTensor TransformsType;
const Interpolation interpolation_; const Interpolation interpolation;
FillProjectiveTransform(Interpolation interpolation) explicit FillProjectiveTransform(Interpolation interpolation)
: interpolation_(interpolation) {} : interpolation(interpolation) {}
EIGEN_ALWAYS_INLINE EIGEN_ALWAYS_INLINE
void operator()(const Device& device, OutputType* output, void operator()(const Device& device, OutputType* output,
const InputType& images, const InputType& images, const TransformsType& transform,
const TransformsType& transform) const { const Mode fill_mode) const {
output->device(device) = output->generate( switch (fill_mode) {
ProjectiveGenerator<Device, T>(images, transform, interpolation_)); case Mode::REFLECT:
output->device(device) =
output->generate(ProjectiveGenerator<Device, T, Mode::REFLECT>(
images, transform, interpolation));
break;
case Mode::WRAP:
output->device(device) =
output->generate(ProjectiveGenerator<Device, T, Mode::WRAP>(
images, transform, interpolation));
break;
case Mode::CONSTANT:
output->device(device) =
output->generate(ProjectiveGenerator<Device, T, Mode::CONSTANT>(
images, transform, interpolation));
break;
}
} }
}; };

View File

@ -1028,7 +1028,6 @@ REGISTER_OP("GenerateBoundingBoxProposals")
return Status::OK(); return Status::OK();
}); });
// TODO(ringwalt): Add a "fill_mode" attr with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0). // TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
// V2 op supports output_shape. V1 op is in contrib. // V2 op supports output_shape. V1 op is in contrib.
REGISTER_OP("ImageProjectiveTransformV2") REGISTER_OP("ImageProjectiveTransformV2")
@ -1037,6 +1036,7 @@ REGISTER_OP("ImageProjectiveTransformV2")
.Input("output_shape: int32") .Input("output_shape: int32")
.Attr("dtype: {uint8, int32, int64, float16, float32, float64}") .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
.Attr("interpolation: string") .Attr("interpolation: string")
.Attr("fill_mode: string = 'CONSTANT'")
.Output("transformed_images: dtype") .Output("transformed_images: dtype")
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
ShapeHandle input; ShapeHandle input;

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -418,9 +419,15 @@ class RandomTranslation(Layer):
When represented as a single float, this value is used for both the upper When represented as a single float, this value is used for both the upper
and lower bound. and lower bound.
fill_mode: Points outside the boundaries of the input are filled according fill_mode: Points outside the boundaries of the input are filled according
to the given mode (one of `{'nearest', 'bilinear'}`). to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
fill_value: Value used for points outside the boundaries of the input if - *reflect*: `(d c b a | a b c d | d c b a)`
`mode='constant'`. The input is extended by reflecting about the edge of the last pixel.
- *constant*: `(k k k k | a b c d | k k k k)`
The input is extended by filling all values beyond the edge with the
same constant value k = 0.
- *wrap*: `(a b c d | a b c d | a b c d)`
The input is extended by wrapping around to the opposite edge.
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
seed: Integer. Used to create a random seed. seed: Integer. Used to create a random seed.
name: A string, the name of the layer. name: A string, the name of the layer.
@ -440,8 +447,8 @@ class RandomTranslation(Layer):
def __init__(self, def __init__(self,
height_factor, height_factor,
width_factor, width_factor,
fill_mode='nearest', fill_mode='reflect',
fill_value=0., interpolation='bilinear',
seed=None, seed=None,
name=None, name=None,
**kwargs): **kwargs):
@ -471,11 +478,16 @@ class RandomTranslation(Layer):
raise ValueError('`width_factor` must have values between [-1, 1], ' raise ValueError('`width_factor` must have values between [-1, 1], '
'got {}'.format(width_factor)) 'got {}'.format(width_factor))
if fill_mode not in {'nearest', 'bilinear'}: if fill_mode not in {'reflect', 'wrap', 'constant'}:
raise NotImplementedError( raise NotImplementedError(
'`fill_mode` {} is not supported yet.'.format(fill_mode)) 'Unknown `fill_mode` {}. Only `reflect`, `wrap` and '
'`constant` are supported.'.format(fill_mode))
if interpolation not in {'nearest', 'bilinear'}:
raise NotImplementedError(
'Unknown `interpolation` {}. Only `nearest` and '
'`bilinear` are supported.'.format(interpolation))
self.fill_mode = fill_mode self.fill_mode = fill_mode
self.fill_value = fill_value self.interpolation = interpolation
self.seed = seed self.seed = seed
self._rng = make_generator(self.seed) self._rng = make_generator(self.seed)
self.input_spec = InputSpec(ndim=4) self.input_spec = InputSpec(ndim=4)
@ -508,7 +520,8 @@ class RandomTranslation(Layer):
return transform( return transform(
inputs, inputs,
get_translation_matrix(translations), get_translation_matrix(translations),
interpolation=self.fill_mode) interpolation=self.interpolation,
fill_mode=self.fill_mode)
output = tf_utils.smart_cond(training, random_translated_inputs, output = tf_utils.smart_cond(training, random_translated_inputs,
lambda: inputs) lambda: inputs)
@ -523,7 +536,7 @@ class RandomTranslation(Layer):
'height_factor': self.height_factor, 'height_factor': self.height_factor,
'width_factor': self.width_factor, 'width_factor': self.width_factor,
'fill_mode': self.fill_mode, 'fill_mode': self.fill_mode,
'fill_value': self.fill_value, 'interpolation': self.interpolation,
'seed': self.seed, 'seed': self.seed,
} }
base_config = super(RandomTranslation, self).get_config() base_config = super(RandomTranslation, self).get_config()
@ -565,7 +578,8 @@ def get_translation_matrix(translations, name=None):
def transform(images, def transform(images,
transforms, transforms,
interpolation='nearest', fill_mode='reflect',
interpolation='bilinear',
output_shape=None, output_shape=None,
name=None): name=None):
"""Applies the given transform(s) to the image(s). """Applies the given transform(s) to the image(s).
@ -582,11 +596,33 @@ def transform(images,
`k = c0 x + c1 y + 1`. The transforms are *inverted* compared to the `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to the
transform mapping input points to output points. Note that gradients are transform mapping input points to output points. Note that gradients are
not backpropagated into transformation parameters. not backpropagated into transformation parameters.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". fill_mode: Points outside the boundaries of the input are filled according
to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
output_shape: Output dimesion after the transform, [height, width]. If None, output_shape: Output dimesion after the transform, [height, width]. If None,
output is the same size as input image. output is the same size as input image.
name: The name of the op. name: The name of the op.
## Fill mode.
Behavior for each valid value is as follows:
reflect (d c b a | a b c d | d c b a)
The input is extended by reflecting about the edge of the last pixel.
constant (k k k k | a b c d | k k k k)
The input is extended by filling all values beyond the edge with the same
constant value k = 0.
wrap (a b c d | a b c d | a b c d)
The input is extended by wrapping around to the opposite edge.
Input shape:
4D tensor with shape: `(samples, height, width, channels)`,
data_format='channels_last'.
Output shape:
4D tensor with shape: `(samples, height, width, channels)`,
data_format='channels_last'.
Returns: Returns:
Image(s) with the same type and shape as `images`, with the given Image(s) with the same type and shape as `images`, with the given
transform(s) applied. Transformed coordinates outside of the input image transform(s) applied. Transformed coordinates outside of the input image
@ -612,6 +648,13 @@ def transform(images,
'new_height, new_width, instead got ' 'new_height, new_width, instead got '
'{}'.format(output_shape)) '{}'.format(output_shape))
if compat.forward_compatible(2020, 3, 25):
return image_ops.image_projective_transform_v2(
images,
output_shape=output_shape,
transforms=transforms,
fill_mode=fill_mode.upper(),
interpolation=interpolation.upper())
return image_ops.image_projective_transform_v2( return image_ops.image_projective_transform_v2(
images, images,
output_shape=output_shape, output_shape=output_shape,
@ -680,11 +723,24 @@ class RandomRotation(Layer):
2 representing lower and upper bound for rotating clockwise and 2 representing lower and upper bound for rotating clockwise and
counter-clockwise. When represented as a single float, lower = upper. counter-clockwise. When represented as a single float, lower = upper.
fill_mode: Points outside the boundaries of the input are filled according fill_mode: Points outside the boundaries of the input are filled according
to the given mode (one of `{'constant', 'nearest', 'bilinear', 'reflect', to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
'wrap'}`). - *reflect*: `(d c b a | a b c d | d c b a)`
The input is extended by reflecting about the edge of the last pixel.
- *constant*: `(k k k k | a b c d | k k k k)`
The input is extended by filling all values beyond the edge with the
same constant value k = 0.
- *wrap*: `(a b c d | a b c d | a b c d)`
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
seed: Integer. Used to create a random seed. seed: Integer. Used to create a random seed.
name: A string, the name of the layer. name: A string, the name of the layer.
Input shape:
4D tensor with shape: `(samples, height, width, channels)`,
data_format='channels_last'.
Output shape:
4D tensor with shape: `(samples, height, width, channels)`,
data_format='channels_last'.
Raise: Raise:
ValueError: if lower bound is not between [0, 1], or upper bound is ValueError: if lower bound is not between [0, 1], or upper bound is
negative. negative.
@ -692,7 +748,8 @@ class RandomRotation(Layer):
def __init__(self, def __init__(self,
factor, factor,
fill_mode='nearest', fill_mode='reflect',
interpolation='bilinear',
seed=None, seed=None,
name=None, name=None,
**kwargs): **kwargs):
@ -705,10 +762,16 @@ class RandomRotation(Layer):
if self.lower < 0. or self.upper < 0.: if self.lower < 0. or self.upper < 0.:
raise ValueError('Factor cannot have negative values, ' raise ValueError('Factor cannot have negative values, '
'got {}'.format(factor)) 'got {}'.format(factor))
if fill_mode not in {'nearest', 'bilinear'}: if fill_mode not in {'reflect', 'wrap', 'constant'}:
raise NotImplementedError( raise NotImplementedError(
'`fill_mode` {} is not supported yet.'.format(fill_mode)) 'Unknown `fill_mode` {}. Only `reflect`, `wrap` and '
'`constant` are supported.'.format(fill_mode))
if interpolation not in {'nearest', 'bilinear'}:
raise NotImplementedError(
'Unknown `interpolation` {}. Only `nearest` and '
'`bilinear` are supported.'.format(interpolation))
self.fill_mode = fill_mode self.fill_mode = fill_mode
self.interpolation = interpolation
self.seed = seed self.seed = seed
self._rng = make_generator(self.seed) self._rng = make_generator(self.seed)
self.input_spec = InputSpec(ndim=4) self.input_spec = InputSpec(ndim=4)
@ -732,7 +795,8 @@ class RandomRotation(Layer):
return transform( return transform(
inputs, inputs,
get_rotation_matrix(angles, img_hd, img_wd), get_rotation_matrix(angles, img_hd, img_wd),
interpolation=self.fill_mode) fill_mode=self.fill_mode,
interpolation=self.interpolation)
output = tf_utils.smart_cond(training, random_rotated_inputs, output = tf_utils.smart_cond(training, random_rotated_inputs,
lambda: inputs) lambda: inputs)
@ -746,6 +810,7 @@ class RandomRotation(Layer):
config = { config = {
'factor': self.factor, 'factor': self.factor,
'fill_mode': self.fill_mode, 'fill_mode': self.fill_mode,
'interpolation': self.interpolation,
'seed': self.seed, 'seed': self.seed,
} }
base_config = super(RandomRotation, self).get_config() base_config = super(RandomRotation, self).get_config()
@ -768,9 +833,14 @@ class RandomZoom(Layer):
upper and lower bound. For instance, `width_factor=(0.2, 0.3)` result in upper and lower bound. For instance, `width_factor=(0.2, 0.3)` result in
an output zoom varying in the range `[original * 20%, original * 30%]`. an output zoom varying in the range `[original * 20%, original * 30%]`.
fill_mode: Points outside the boundaries of the input are filled according fill_mode: Points outside the boundaries of the input are filled according
to the given mode (one of `{'nearest', 'bilinear'}`). to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
fill_value: Value used for points outside the boundaries of the input if - *reflect*: `(d c b a | a b c d | d c b a)`
`mode='constant'`. The input is extended by reflecting about the edge of the last pixel.
- *constant*: `(k k k k | a b c d | k k k k)`
The input is extended by filling all values beyond the edge with the
same constant value k = 0.
- *wrap*: `(a b c d | a b c d | a b c d)`
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
seed: Integer. Used to create a random seed. seed: Integer. Used to create a random seed.
name: A string, the name of the layer. name: A string, the name of the layer.
@ -790,8 +860,8 @@ class RandomZoom(Layer):
def __init__(self, def __init__(self,
height_factor, height_factor,
width_factor, width_factor,
fill_mode='nearest', fill_mode='reflect',
fill_value=0., interpolation='bilinear',
seed=None, seed=None,
name=None, name=None,
**kwargs): **kwargs):
@ -821,11 +891,16 @@ class RandomZoom(Layer):
raise ValueError('`width_factor` cannot have lower bound larger than ' raise ValueError('`width_factor` cannot have lower bound larger than '
'upper bound, got {}.'.format(width_factor)) 'upper bound, got {}.'.format(width_factor))
if fill_mode not in {'nearest', 'bilinear'}: if fill_mode not in {'reflect', 'wrap', 'constant'}:
raise NotImplementedError( raise NotImplementedError(
'`fill_mode` {} is not supported yet.'.format(fill_mode)) 'Unknown `fill_mode` {}. Only `reflect`, `wrap` and '
'`constant` are supported.'.format(fill_mode))
if interpolation not in {'nearest', 'bilinear'}:
raise NotImplementedError(
'Unknown `interpolation` {}. Only `nearest` and '
'`bilinear` are supported.'.format(interpolation))
self.fill_mode = fill_mode self.fill_mode = fill_mode
self.fill_value = fill_value self.interpolation = interpolation
self.seed = seed self.seed = seed
self._rng = make_generator(self.seed) self._rng = make_generator(self.seed)
self.input_spec = InputSpec(ndim=4) self.input_spec = InputSpec(ndim=4)
@ -857,7 +932,8 @@ class RandomZoom(Layer):
dtype=inputs.dtype) dtype=inputs.dtype)
return transform( return transform(
inputs, get_zoom_matrix(zooms, img_hd, img_wd), inputs, get_zoom_matrix(zooms, img_hd, img_wd),
interpolation=self.fill_mode) fill_mode=self.fill_mode,
interpolation=self.interpolation)
output = tf_utils.smart_cond(training, random_zoomed_inputs, output = tf_utils.smart_cond(training, random_zoomed_inputs,
lambda: inputs) lambda: inputs)
@ -872,7 +948,7 @@ class RandomZoom(Layer):
'height_factor': self.height_factor, 'height_factor': self.height_factor,
'width_factor': self.width_factor, 'width_factor': self.width_factor,
'fill_mode': self.fill_mode, 'fill_mode': self.fill_mode,
'fill_value': self.fill_value, 'interpolation': self.interpolation,
'seed': self.seed, 'seed': self.seed,
} }
base_config = super(RandomZoom, self).get_config() base_config = super(RandomZoom, self).get_config()

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
@ -509,6 +510,250 @@ class RandomTranslationTest(keras_parameterized.TestCase):
self.assertEqual(layer_1.name, layer.name) self.assertEqual(layer_1.name, layer.name)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class RandomTransformTest(keras_parameterized.TestCase):
def _run_random_transform_with_mock(self,
transform_matrix,
expected_output,
mode,
interpolation='bilinear'):
inp = np.arange(15).reshape((1, 5, 3, 1)).astype(np.float32)
with self.cached_session(use_gpu=True):
output = image_preprocessing.transform(
inp, transform_matrix, fill_mode=mode, interpolation=interpolation)
self.assertAllClose(expected_output, output)
def test_random_translation_reflect(self):
# reflected output is (dcba|abcd|dcba)
if compat.forward_compatible(2020, 3, 25):
# Test down shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[0., 1., 2.],
[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8],
[9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'reflect')
# Test up shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[3., 4., 5.],
[6., 7., 8],
[9., 10., 11.],
[12., 13., 14.],
[12., 13., 14.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'reflect')
# Test left shift by 1.
# reflected output is (dcba|abcd|dcba)
# pyformat: disable
expected_output = np.asarray(
[[1., 2., 2.],
[4., 5., 5.],
[7., 8., 8.],
[10., 11., 11.],
[13., 14., 14.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'reflect')
# Test right shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[0., 0., 1.],
[3., 3., 4],
[6., 6., 7.],
[9., 9., 10.],
[12., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'reflect')
def test_random_translation_wrap(self):
# warpped output is (abcd|abcd|abcd)
if compat.forward_compatible(2020, 3, 25):
# Test down shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[12., 13., 14.],
[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8],
[9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'wrap')
# Test up shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[3., 4., 5.],
[6., 7., 8],
[9., 10., 11.],
[12., 13., 14.],
[0., 1., 2.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'wrap')
# Test left shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[1., 2., 0.],
[4., 5., 3.],
[7., 8., 6.],
[10., 11., 9.],
[13., 14., 12.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'wrap')
# Test right shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[2., 0., 1.],
[5., 3., 4],
[8., 6., 7.],
[11., 9., 10.],
[14., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'wrap')
def test_random_translation_constant(self):
# constant output is (0000|abcd|0000)
if compat.forward_compatible(2020, 3, 25):
# Test down shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[0., 0., 0.],
[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8],
[9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'constant')
# Test up shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[3., 4., 5.],
[6., 7., 8],
[9., 10., 11.],
[12., 13., 14.],
[0., 0., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'constant')
# Test left shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[1., 2., 0.],
[4., 5., 0.],
[7., 8., 0.],
[10., 11., 0.],
[13., 14., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'constant')
# Test right shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[0., 0., 1.],
[0., 3., 4],
[0., 6., 7.],
[0., 9., 10.],
[0., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'constant')
def test_random_translation_nearest_interpolation(self):
# nearest output is (aaaa|abcd|dddd)
if compat.forward_compatible(2020, 3, 25):
# Test down shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[0., 0., 0.],
[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8],
[9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
self._run_random_transform_with_mock(
transform_matrix, expected_output,
mode='constant', interpolation='nearest')
# Test up shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[3., 4., 5.],
[6., 7., 8],
[9., 10., 11.],
[12., 13., 14.],
[0., 0., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
self._run_random_transform_with_mock(
transform_matrix, expected_output,
mode='constant', interpolation='nearest')
# Test left shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[1., 2., 0.],
[4., 5., 0.],
[7., 8., 0.],
[10., 11., 0.],
[13., 14., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
self._run_random_transform_with_mock(
transform_matrix, expected_output,
mode='constant', interpolation='nearest')
# Test right shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[0., 0., 1.],
[0., 3., 4],
[0., 6., 7.],
[0., 9., 10.],
[0., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
self._run_random_transform_with_mock(
transform_matrix, expected_output,
mode='constant', interpolation='nearest')
@keras_parameterized.run_all_keras_modes(always_skip_v1=True) @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class RandomRotationTest(keras_parameterized.TestCase): class RandomRotationTest(keras_parameterized.TestCase):

View File

@ -245,6 +245,7 @@ def _image_projective_transform_grad(op, grad):
images = op.inputs[0] images = op.inputs[0]
transforms = op.inputs[1] transforms = op.inputs[1]
interpolation = op.get_attr("interpolation") interpolation = op.get_attr("interpolation")
fill_mode = op.get_attr("fill_mode")
image_or_images = ops.convert_to_tensor(images, name="images") image_or_images = ops.convert_to_tensor(images, name="images")
transform_or_transforms = ops.convert_to_tensor( transform_or_transforms = ops.convert_to_tensor(
@ -267,5 +268,6 @@ def _image_projective_transform_grad(op, grad):
images=grad, images=grad,
transforms=transforms, transforms=transforms,
output_shape=array_ops.shape(image_or_images)[1:3], output_shape=array_ops.shape(image_or_images)[1:3],
interpolation=interpolation) interpolation=interpolation,
fill_mode=fill_mode)
return [output, None, None] return [output, None, None]

View File

@ -113,7 +113,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'None\', \'None\'], " argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "add_loss" name: "add_loss"

View File

@ -113,7 +113,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], " argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "add_loss" name: "add_loss"

View File

@ -113,7 +113,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], " argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "add_loss" name: "add_loss"

View File

@ -1778,7 +1778,7 @@ tf_module {
} }
member_method { member_method {
name: "ImageProjectiveTransformV2" name: "ImageProjectiveTransformV2"
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
} }
member_method { member_method {
name: "ImageSummary" name: "ImageSummary"

View File

@ -113,7 +113,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'None\', \'None\'], " argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "add_loss" name: "add_loss"

View File

@ -113,7 +113,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], " argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "add_loss" name: "add_loss"

View File

@ -113,7 +113,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], " argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "add_loss" name: "add_loss"

View File

@ -1778,7 +1778,7 @@ tf_module {
} }
member_method { member_method {
name: "ImageProjectiveTransformV2" name: "ImageProjectiveTransformV2"
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
} }
member_method { member_method {
name: "ImageSummary" name: "ImageSummary"