Support fill_value for ImageProjectiveTransform

Fix conflict

Fix typo

Update number
This commit is contained in:
Tzu-Wei Sung 2020-08-06 11:29:39 -07:00
parent 9f86089e45
commit 13a9df90f4
17 changed files with 400 additions and 74 deletions

View File

@ -0,0 +1,63 @@
op {
graph_op_name: "ImageProjectiveTransformV3"
visibility: HIDDEN
in_arg {
name: "images"
description: <<END
4-D with shape `[batch, height, width, channels]`.
END
}
in_arg {
name: "transforms"
description: <<END
2-D Tensor, `[batch, 8]` or `[1, 8]` matrix, where each row corresponds to a 3 x 3
projective transformation matrix, with the last entry assumed to be 1. If there
is one row, the same transformation will be applied to all images.
END
}
in_arg {
name: "output_shape"
description: <<END
1-D Tensor [new_height, new_width].
END
}
in_arg {
name: "fill_value"
description: <<END
float, the value to be filled when fill_mode is constant".
END
}
out_arg {
name: "transformed_images"
description: <<END
4-D with shape
`[batch, new_height, new_width, channels]`.
END
}
attr {
name: "dtype"
description: <<END
Input dtype.
END
}
attr {
name: "interpolation"
description: <<END
Interpolation method, "NEAREST" or "BILINEAR".
END
}
attr {
name: "fill_mode"
description: <<END
Fill mode, "REFLECT", "WRAP", or "CONSTANT".
END
}
summary: "Applies the given transform to each of the images."
description: <<END
If one row of `transforms` is `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps
the *output* point `(x, y)` to a transformed *input* point
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where
`k = c0 x + c1 y + 1`. If the transformed point lays outside of the input
image, the output pixel is set to fill_value.
END
}

View File

@ -48,6 +48,68 @@ using functor::FillProjectiveTransform;
using generator::Interpolation; using generator::Interpolation;
using generator::Mode; using generator::Mode;
template <typename Device, typename T>
void DoImageProjectiveTransformOp(OpKernelContext* ctx,
const Interpolation& interpolation,
const Mode& fill_mode) {
const Tensor& images_t = ctx->input(0);
const Tensor& transform_t = ctx->input(1);
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
errors::InvalidArgument("Input images must have rank 4"));
OP_REQUIRES(ctx,
(TensorShapeUtils::IsMatrix(transform_t.shape()) &&
(transform_t.dim_size(0) == images_t.dim_size(0) ||
transform_t.dim_size(0) == 1) &&
transform_t.dim_size(1) == 8),
errors::InvalidArgument(
"Input transform should be num_images x 8 or 1 x 8"));
int32 out_height, out_width;
// Kernel is shared by legacy "ImageProjectiveTransform" op with 2 args.
if (ctx->num_inputs() >= 3) {
const Tensor& shape_t = ctx->input(2);
OP_REQUIRES(ctx, shape_t.dims() == 1,
errors::InvalidArgument("output shape must be 1-dimensional",
shape_t.shape().DebugString()));
OP_REQUIRES(ctx, shape_t.NumElements() == 2,
errors::InvalidArgument("output shape must have two elements",
shape_t.shape().DebugString()));
auto shape_vec = shape_t.vec<int32>();
out_height = shape_vec(0);
out_width = shape_vec(1);
OP_REQUIRES(ctx, out_height > 0 && out_width > 0,
errors::InvalidArgument("output dimensions must be positive"));
} else {
// Shape is N (batch size), H (height), W (width), C (channels).
out_height = images_t.shape().dim_size(1);
out_width = images_t.shape().dim_size(2);
}
T fill_value(0);
// Kernel is shared by "ImageProjectiveTransformV2" with 3 args.
if (ctx->num_inputs() >= 4) {
const Tensor& fill_value_t = ctx->input(3);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(fill_value_t.shape()),
errors::InvalidArgument("fill_value must be a scalar",
fill_value_t.shape().DebugString()));
fill_value = static_cast<T>(*(fill_value_t.scalar<float>().data()));
}
Tensor* output_t;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(0,
TensorShape({images_t.dim_size(0), out_height,
out_width, images_t.dim_size(3)}),
&output_t));
auto output = output_t->tensor<T, 4>();
auto images = images_t.tensor<T, 4>();
auto transform = transform_t.matrix<float>();
(FillProjectiveTransform<Device, T>(interpolation))(
ctx->eigen_device<Device>(), &output, images, transform, fill_mode,
fill_value);
}
template <typename Device, typename T> template <typename Device, typename T>
class ImageProjectiveTransformV2 : public OpKernel { class ImageProjectiveTransformV2 : public OpKernel {
private: private:
@ -84,52 +146,7 @@ class ImageProjectiveTransformV2 : public OpKernel {
} }
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0); DoImageProjectiveTransformOp<Device, T>(ctx, interpolation_, fill_mode_);
const Tensor& transform_t = ctx->input(1);
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
errors::InvalidArgument("Input images must have rank 4"));
OP_REQUIRES(ctx,
(TensorShapeUtils::IsMatrix(transform_t.shape()) &&
(transform_t.dim_size(0) == images_t.dim_size(0) ||
transform_t.dim_size(0) == 1) &&
transform_t.dim_size(1) == 8),
errors::InvalidArgument(
"Input transform should be num_images x 8 or 1 x 8"));
int32 out_height, out_width;
// Kernel is shared by legacy "ImageProjectiveTransform" op with 2 args.
if (ctx->num_inputs() >= 3) {
const Tensor& shape_t = ctx->input(2);
OP_REQUIRES(ctx, shape_t.dims() == 1,
errors::InvalidArgument("output shape must be 1-dimensional",
shape_t.shape().DebugString()));
OP_REQUIRES(ctx, shape_t.NumElements() == 2,
errors::InvalidArgument("output shape must have two elements",
shape_t.shape().DebugString()));
auto shape_vec = shape_t.vec<int32>();
out_height = shape_vec(0);
out_width = shape_vec(1);
OP_REQUIRES(
ctx, out_height > 0 && out_width > 0,
errors::InvalidArgument("output dimensions must be positive"));
} else {
// Shape is N (batch size), H (height), W (width), C (channels).
out_height = images_t.shape().dim_size(1);
out_width = images_t.shape().dim_size(2);
}
Tensor* output_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(
0,
TensorShape({images_t.dim_size(0), out_height,
out_width, images_t.dim_size(3)}),
&output_t));
auto output = output_t->tensor<T, 4>();
auto images = images_t.tensor<T, 4>();
auto transform = transform_t.matrix<float>();
(FillProjectiveTransform<Device, T>(interpolation_))(
ctx->eigen_device<Device>(), &output, images, transform, fill_mode_);
} }
}; };
@ -148,6 +165,29 @@ TF_CALL_double(REGISTER);
#undef REGISTER #undef REGISTER
template <typename Device, typename T>
class ImageProjectiveTransformV3
: public ImageProjectiveTransformV2<Device, T> {
public:
explicit ImageProjectiveTransformV3(OpKernelConstruction* ctx)
: ImageProjectiveTransformV2<Device, T>(ctx) {}
};
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV3") \
.Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("dtype"), \
ImageProjectiveTransformV3<CPUDevice, TYPE>)
TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER
#if GOOGLE_CUDA #if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice; typedef Eigen::GpuDevice GPUDevice;
@ -161,7 +201,8 @@ namespace functor {
template <> \ template <> \
void FillProjectiveTransform<GPUDevice, TYPE>::operator()( \ void FillProjectiveTransform<GPUDevice, TYPE>::operator()( \
const GPUDevice& device, OutputType* output, const InputType& images, \ const GPUDevice& device, OutputType* output, const InputType& images, \
const TransformsType& transform, const Mode fill_mode) const; \ const TransformsType& transform, const Mode fill_mode, \
const TYPE fill_value) const; \
extern template struct FillProjectiveTransform<GPUDevice, TYPE> extern template struct FillProjectiveTransform<GPUDevice, TYPE>
TF_CALL_uint8(DECLARE_PROJECT_FUNCTOR); TF_CALL_uint8(DECLARE_PROJECT_FUNCTOR);
@ -204,6 +245,23 @@ TF_CALL_double(REGISTER);
#undef REGISTER #undef REGISTER
#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV3") \
.Device(DEVICE_GPU) \
.TypeConstraint<TYPE>("dtype") \
.HostMemory("output_shape") \
.HostMemory("fill_value"), \
ImageProjectiveTransformV3<GPUDevice, TYPE>)
TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -107,17 +107,20 @@ class ProjectiveGenerator {
typename TTypes<T, 4>::ConstTensor input_; typename TTypes<T, 4>::ConstTensor input_;
typename TTypes<float>::ConstMatrix transforms_; typename TTypes<float>::ConstMatrix transforms_;
const Interpolation interpolation_; const Interpolation interpolation_;
const T fill_value_;
public: public:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input, ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
typename TTypes<float>::ConstMatrix transforms, typename TTypes<float>::ConstMatrix transforms,
const Interpolation interpolation) const Interpolation interpolation, const T fill_value)
: input_(input), transforms_(transforms), interpolation_(interpolation) {} : input_(input),
transforms_(transforms),
interpolation_(interpolation),
fill_value_(fill_value) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()(const array<DenseIndex, 4>& coords) const { operator()(const array<DenseIndex, 4>& coords) const {
const T fill_value = T(0);
const int64 output_y = coords[1]; const int64 output_y = coords[1];
const int64 output_x = coords[2]; const int64 output_x = coords[2];
const float* transform = const float* transform =
@ -126,9 +129,9 @@ class ProjectiveGenerator {
: &transforms_.data()[transforms_.dimension(1) * coords[0]]; : &transforms_.data()[transforms_.dimension(1) * coords[0]];
float projection = transform[6] * output_x + transform[7] * output_y + 1.f; float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
if (projection == 0) { if (projection == 0) {
// Return the fill value (0) for infinite coordinates, // Return the fill value for infinite coordinates,
// which are outside the input image // which are outside the input image
return fill_value; return fill_value_;
} }
const float input_x = const float input_x =
(transform[0] * output_x + transform[1] * output_y + transform[2]) / (transform[0] * output_x + transform[1] * output_y + transform[2]) /
@ -146,13 +149,13 @@ class ProjectiveGenerator {
const DenseIndex channels = coords[3]; const DenseIndex channels = coords[3];
switch (interpolation_) { switch (interpolation_) {
case NEAREST: case NEAREST:
return nearest_interpolation(batch, y, x, channels, fill_value); return nearest_interpolation(batch, y, x, channels, fill_value_);
case BILINEAR: case BILINEAR:
return bilinear_interpolation(batch, y, x, channels, fill_value); return bilinear_interpolation(batch, y, x, channels, fill_value_);
} }
// Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST // Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST
// or INTERPOLATION_BILINEAR. // or INTERPOLATION_BILINEAR.
return fill_value; return fill_value_;
} }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
@ -225,27 +228,27 @@ struct FillProjectiveTransform {
EIGEN_ALWAYS_INLINE EIGEN_ALWAYS_INLINE
void operator()(const Device& device, OutputType* output, void operator()(const Device& device, OutputType* output,
const InputType& images, const TransformsType& transform, const InputType& images, const TransformsType& transform,
const Mode fill_mode) const { const Mode fill_mode, const T fill_value) const {
switch (fill_mode) { switch (fill_mode) {
case Mode::FILL_REFLECT: case Mode::FILL_REFLECT:
output->device(device) = output->device(device) =
output->generate(ProjectiveGenerator<Device, T, Mode::FILL_REFLECT>( output->generate(ProjectiveGenerator<Device, T, Mode::FILL_REFLECT>(
images, transform, interpolation)); images, transform, interpolation, fill_value));
break; break;
case Mode::FILL_WRAP: case Mode::FILL_WRAP:
output->device(device) = output->device(device) =
output->generate(ProjectiveGenerator<Device, T, Mode::FILL_WRAP>( output->generate(ProjectiveGenerator<Device, T, Mode::FILL_WRAP>(
images, transform, interpolation)); images, transform, interpolation, fill_value));
break; break;
case Mode::FILL_CONSTANT: case Mode::FILL_CONSTANT:
output->device(device) = output->generate( output->device(device) = output->generate(
ProjectiveGenerator<Device, T, Mode::FILL_CONSTANT>( ProjectiveGenerator<Device, T, Mode::FILL_CONSTANT>(
images, transform, interpolation)); images, transform, interpolation, fill_value));
break; break;
case Mode::FILL_NEAREST: case Mode::FILL_NEAREST:
output->device(device) = output->device(device) =
output->generate(ProjectiveGenerator<Device, T, Mode::FILL_NEAREST>( output->generate(ProjectiveGenerator<Device, T, Mode::FILL_NEAREST>(
images, transform, interpolation)); images, transform, interpolation, fill_value));
break; break;
} }
} }

View File

@ -1146,8 +1146,9 @@ REGISTER_OP("GenerateBoundingBoxProposals")
return Status::OK(); return Status::OK();
}); });
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0). // V3 op supports fill_value.
// V2 op supports output_shape. V1 op is in contrib. // V2 op supports output_shape.
// V1 op is in contrib.
REGISTER_OP("ImageProjectiveTransformV2") REGISTER_OP("ImageProjectiveTransformV2")
.Input("images: dtype") .Input("images: dtype")
.Input("transforms: float32") .Input("transforms: float32")
@ -1163,4 +1164,20 @@ REGISTER_OP("ImageProjectiveTransformV2")
c->Dim(input, 3)); c->Dim(input, 3));
}); });
REGISTER_OP("ImageProjectiveTransformV3")
.Input("images: dtype")
.Input("transforms: float32")
.Input("output_shape: int32")
.Input("fill_value: float32")
.Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
.Attr("interpolation: string")
.Attr("fill_mode: string = 'CONSTANT'")
.Output("transformed_images: dtype")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
c->Dim(input, 3));
});
} // namespace tensorflow } // namespace tensorflow

View File

@ -19102,6 +19102,54 @@ op {
} }
} }
} }
op {
name: "ImageProjectiveTransformV3"
input_arg {
name: "images"
type_attr: "dtype"
}
input_arg {
name: "transforms"
type: DT_FLOAT
}
input_arg {
name: "output_shape"
type: DT_INT32
}
input_arg {
name: "fill_value"
type: DT_FLOAT
}
output_arg {
name: "transformed_images"
type_attr: "dtype"
}
attr {
name: "dtype"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT32
type: DT_INT64
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "interpolation"
type: "string"
}
attr {
name: "fill_mode"
type: "string"
default_value {
s: "CONSTANT"
}
}
}
op { op {
name: "ImageSummary" name: "ImageSummary"
input_arg { input_arg {

View File

@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices( absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
const tensorflow::string &op_name) { const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 349> a = {{ static std::array<OpIndexInfo, 350> a = {{
{"Acosh"}, {"Acosh"},
{"AllToAll", 1, {0}}, {"AllToAll", 1, {0}},
{"ApproximateEqual"}, {"ApproximateEqual"},
@ -152,6 +152,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
{"IdentityReader"}, {"IdentityReader"},
{"Imag"}, {"Imag"},
{"ImageProjectiveTransformV2", 1, {2}}, {"ImageProjectiveTransformV2", 1, {2}},
{"ImageProjectiveTransformV3", 2, {2, 3}},
{"ImageSummary"}, {"ImageSummary"},
{"InitializeTable"}, {"InitializeTable"},
{"InitializeTableFromTextFile"}, {"InitializeTableFromTextFile"},
@ -412,7 +413,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices( absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
const tensorflow::string &op_name) { const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 465> a = {{ static std::array<OpIndexInfo, 466> a = {{
{"Abs"}, {"Abs"},
{"AccumulateNV2"}, {"AccumulateNV2"},
{"Acos"}, {"Acos"},
@ -568,6 +569,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
{"Igammac"}, {"Igammac"},
{"Imag"}, {"Imag"},
{"ImageProjectiveTransformV2"}, {"ImageProjectiveTransformV2"},
{"ImageProjectiveTransformV3"},
{"ImageSummary"}, {"ImageSummary"},
{"InitializeTable"}, {"InitializeTable"},
{"InitializeTableFromTextFile"}, {"InitializeTableFromTextFile"},

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
@ -466,6 +467,8 @@ class RandomTranslation(PreprocessingLayer):
The input is extended by wrapping around to the opposite edge. The input is extended by wrapping around to the opposite edge.
- *nearest*: `(a a a a | a b c d | d d d d)` - *nearest*: `(a a a a | a b c d | d d d d)`
The input is extended by the nearest pixel. The input is extended by the nearest pixel.
fill_value: a float represents the value to be filled outside the
boundaries when `fill_mode` is "constant".
interpolation: Interpolation mode. Supported values: "nearest", "bilinear". interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
seed: Integer. Used to create a random seed. seed: Integer. Used to create a random seed.
name: A string, the name of the layer. name: A string, the name of the layer.
@ -487,6 +490,7 @@ class RandomTranslation(PreprocessingLayer):
height_factor, height_factor,
width_factor, width_factor,
fill_mode='reflect', fill_mode='reflect',
fill_value=0.0,
interpolation='bilinear', interpolation='bilinear',
seed=None, seed=None,
name=None, name=None,
@ -522,6 +526,7 @@ class RandomTranslation(PreprocessingLayer):
check_fill_mode_and_interpolation(fill_mode, interpolation) check_fill_mode_and_interpolation(fill_mode, interpolation)
self.fill_mode = fill_mode self.fill_mode = fill_mode
self.fill_value = fill_value
self.interpolation = interpolation self.interpolation = interpolation
self.seed = seed self.seed = seed
self._rng = make_generator(self.seed) self._rng = make_generator(self.seed)
@ -559,7 +564,8 @@ class RandomTranslation(PreprocessingLayer):
inputs, inputs,
get_translation_matrix(translations), get_translation_matrix(translations),
interpolation=self.interpolation, interpolation=self.interpolation,
fill_mode=self.fill_mode) fill_mode=self.fill_mode,
fill_value=self.fill_value)
output = control_flow_util.smart_cond(training, random_translated_inputs, output = control_flow_util.smart_cond(training, random_translated_inputs,
lambda: inputs) lambda: inputs)
@ -574,6 +580,7 @@ class RandomTranslation(PreprocessingLayer):
'height_factor': self.height_factor, 'height_factor': self.height_factor,
'width_factor': self.width_factor, 'width_factor': self.width_factor,
'fill_mode': self.fill_mode, 'fill_mode': self.fill_mode,
'fill_value': self.fill_value,
'interpolation': self.interpolation, 'interpolation': self.interpolation,
'seed': self.seed, 'seed': self.seed,
} }
@ -617,6 +624,7 @@ def get_translation_matrix(translations, name=None):
def transform(images, def transform(images,
transforms, transforms,
fill_mode='reflect', fill_mode='reflect',
fill_value=0.0,
interpolation='bilinear', interpolation='bilinear',
output_shape=None, output_shape=None,
name=None): name=None):
@ -636,6 +644,8 @@ def transform(images,
not backpropagated into transformation parameters. not backpropagated into transformation parameters.
fill_mode: Points outside the boundaries of the input are filled according fill_mode: Points outside the boundaries of the input are filled according
to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
fill_value: a float represents the value to be filled outside the
boundaries when `fill_mode` is "constant".
interpolation: Interpolation mode. Supported values: "nearest", "bilinear". interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
output_shape: Output dimesion after the transform, [height, width]. If None, output_shape: Output dimesion after the transform, [height, width]. If None,
output is the same size as input image. output is the same size as input image.
@ -689,6 +699,18 @@ def transform(images,
'new_height, new_width, instead got ' 'new_height, new_width, instead got '
'{}'.format(output_shape)) '{}'.format(output_shape))
fill_value = ops.convert_to_tensor_v2(
fill_value, dtypes.float32, name='fill_value')
if compat.forward_compatible(2020, 8, 5):
return gen_image_ops.ImageProjectiveTransformV3(
images=images,
output_shape=output_shape,
fill_value=fill_value,
transforms=transforms,
fill_mode=fill_mode.upper(),
interpolation=interpolation.upper())
return gen_image_ops.ImageProjectiveTransformV2( return gen_image_ops.ImageProjectiveTransformV2(
images=images, images=images,
output_shape=output_shape, output_shape=output_shape,
@ -774,6 +796,8 @@ class RandomRotation(PreprocessingLayer):
The input is extended by wrapping around to the opposite edge. The input is extended by wrapping around to the opposite edge.
- *nearest*: `(a a a a | a b c d | d d d d)` - *nearest*: `(a a a a | a b c d | d d d d)`
The input is extended by the nearest pixel. The input is extended by the nearest pixel.
fill_value: a float represents the value to be filled outside the
boundaries when `fill_mode` is "constant".
interpolation: Interpolation mode. Supported values: "nearest", "bilinear". interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
seed: Integer. Used to create a random seed. seed: Integer. Used to create a random seed.
name: A string, the name of the layer. name: A string, the name of the layer.
@ -793,6 +817,7 @@ class RandomRotation(PreprocessingLayer):
def __init__(self, def __init__(self,
factor, factor,
fill_mode='reflect', fill_mode='reflect',
fill_value=0.0,
interpolation='bilinear', interpolation='bilinear',
seed=None, seed=None,
name=None, name=None,
@ -809,6 +834,7 @@ class RandomRotation(PreprocessingLayer):
'got {}'.format(factor)) 'got {}'.format(factor))
check_fill_mode_and_interpolation(fill_mode, interpolation) check_fill_mode_and_interpolation(fill_mode, interpolation)
self.fill_mode = fill_mode self.fill_mode = fill_mode
self.fill_value = fill_value
self.interpolation = interpolation self.interpolation = interpolation
self.seed = seed self.seed = seed
self._rng = make_generator(self.seed) self._rng = make_generator(self.seed)
@ -834,6 +860,7 @@ class RandomRotation(PreprocessingLayer):
inputs, inputs,
get_rotation_matrix(angles, img_hd, img_wd), get_rotation_matrix(angles, img_hd, img_wd),
fill_mode=self.fill_mode, fill_mode=self.fill_mode,
fill_value=self.fill_value,
interpolation=self.interpolation) interpolation=self.interpolation)
output = control_flow_util.smart_cond(training, random_rotated_inputs, output = control_flow_util.smart_cond(training, random_rotated_inputs,
@ -848,6 +875,7 @@ class RandomRotation(PreprocessingLayer):
config = { config = {
'factor': self.factor, 'factor': self.factor,
'fill_mode': self.fill_mode, 'fill_mode': self.fill_mode,
'fill_value': self.fill_value,
'interpolation': self.interpolation, 'interpolation': self.interpolation,
'seed': self.seed, 'seed': self.seed,
} }
@ -889,6 +917,8 @@ class RandomZoom(PreprocessingLayer):
The input is extended by wrapping around to the opposite edge. The input is extended by wrapping around to the opposite edge.
- *nearest*: `(a a a a | a b c d | d d d d)` - *nearest*: `(a a a a | a b c d | d d d d)`
The input is extended by the nearest pixel. The input is extended by the nearest pixel.
fill_value: a float represents the value to be filled outside the
boundaries when `fill_mode` is "constant".
interpolation: Interpolation mode. Supported values: "nearest", "bilinear". interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
seed: Integer. Used to create a random seed. seed: Integer. Used to create a random seed.
name: A string, the name of the layer. name: A string, the name of the layer.
@ -914,11 +944,11 @@ class RandomZoom(PreprocessingLayer):
negative. negative.
""" """
# TODO(b/156526279): Add `fill_value` argument.
def __init__(self, def __init__(self,
height_factor, height_factor,
width_factor=None, width_factor=None,
fill_mode='reflect', fill_mode='reflect',
fill_value=0.0,
interpolation='bilinear', interpolation='bilinear',
seed=None, seed=None,
name=None, name=None,
@ -951,6 +981,7 @@ class RandomZoom(PreprocessingLayer):
check_fill_mode_and_interpolation(fill_mode, interpolation) check_fill_mode_and_interpolation(fill_mode, interpolation)
self.fill_mode = fill_mode self.fill_mode = fill_mode
self.fill_value = fill_value
self.interpolation = interpolation self.interpolation = interpolation
self.seed = seed self.seed = seed
self._rng = make_generator(self.seed) self._rng = make_generator(self.seed)
@ -985,6 +1016,7 @@ class RandomZoom(PreprocessingLayer):
return transform( return transform(
inputs, get_zoom_matrix(zooms, img_hd, img_wd), inputs, get_zoom_matrix(zooms, img_hd, img_wd),
fill_mode=self.fill_mode, fill_mode=self.fill_mode,
fill_value=self.fill_value,
interpolation=self.interpolation) interpolation=self.interpolation)
output = control_flow_util.smart_cond(training, random_zoomed_inputs, output = control_flow_util.smart_cond(training, random_zoomed_inputs,
@ -1000,6 +1032,7 @@ class RandomZoom(PreprocessingLayer):
'height_factor': self.height_factor, 'height_factor': self.height_factor,
'width_factor': self.width_factor, 'width_factor': self.width_factor,
'fill_mode': self.fill_mode, 'fill_mode': self.fill_mode,
'fill_value': self.fill_value,
'interpolation': self.interpolation, 'interpolation': self.interpolation,
'seed': self.seed, 'seed': self.seed,
} }

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.framework import test_util as tf_test_util
@ -698,11 +699,13 @@ class RandomTransformTest(keras_parameterized.TestCase):
transform_matrix, transform_matrix,
expected_output, expected_output,
mode, mode,
fill_value=0.0,
interpolation='bilinear'): interpolation='bilinear'):
inp = np.arange(15).reshape((1, 5, 3, 1)).astype(np.float32) inp = np.arange(15).reshape((1, 5, 3, 1)).astype(np.float32)
with self.cached_session(use_gpu=True): with self.cached_session(use_gpu=True):
output = image_preprocessing.transform( output = image_preprocessing.transform(
inp, transform_matrix, fill_mode=mode, interpolation=interpolation) inp, transform_matrix, fill_mode=mode,
fill_value=fill_value, interpolation=interpolation)
self.assertAllClose(expected_output, output) self.assertAllClose(expected_output, output)
def test_random_translation_reflect(self): def test_random_translation_reflect(self):
@ -871,7 +874,7 @@ class RandomTransformTest(keras_parameterized.TestCase):
self._run_random_transform_with_mock(transform_matrix, expected_output, self._run_random_transform_with_mock(transform_matrix, expected_output,
'nearest') 'nearest')
def test_random_translation_constant(self): def test_random_translation_constant_0(self):
# constant output is (0000|abcd|0000) # constant output is (0000|abcd|0000)
# Test down shift by 1. # Test down shift by 1.
@ -926,6 +929,62 @@ class RandomTransformTest(keras_parameterized.TestCase):
self._run_random_transform_with_mock(transform_matrix, expected_output, self._run_random_transform_with_mock(transform_matrix, expected_output,
'constant') 'constant')
def test_random_translation_constant_1(self):
with compat.forward_compatibility_horizon(2020, 8, 6):
# constant output is (1111|abcd|1111)
# Test down shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[1., 1., 1.],
[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8],
[9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'constant', fill_value=1.0)
# Test up shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[3., 4., 5.],
[6., 7., 8],
[9., 10., 11.],
[12., 13., 14.],
[1., 1., 1.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'constant', fill_value=1.0)
# Test left shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[1., 2., 1.],
[4., 5., 1.],
[7., 8., 1.],
[10., 11., 1.],
[13., 14., 1.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'constant', fill_value=1.0)
# Test right shift by 1.
# pyformat: disable
expected_output = np.asarray(
[[1., 0., 1.],
[1., 3., 4],
[1., 6., 7.],
[1., 9., 10.],
[1., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32)
# pyformat: enable
transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
self._run_random_transform_with_mock(transform_matrix, expected_output,
'constant', fill_value=1.0)
def test_random_translation_nearest_interpolation(self): def test_random_translation_nearest_interpolation(self):
# nearest output is (aaaa|abcd|dddd) # nearest output is (aaaa|abcd|dddd)

View File

@ -271,3 +271,38 @@ def _image_projective_transform_grad(op, grad):
interpolation=interpolation, interpolation=interpolation,
fill_mode=fill_mode) fill_mode=fill_mode)
return [output, None, None] return [output, None, None]
@ops.RegisterGradient("ImageProjectiveTransformV3")
def _image_projective_transform_v3_grad(op, grad):
"""Computes the gradient for ImageProjectiveTransform."""
images = op.inputs[0]
transforms = op.inputs[1]
interpolation = op.get_attr("interpolation")
fill_mode = op.get_attr("fill_mode")
image_or_images = ops.convert_to_tensor(images, name="images")
transform_or_transforms = ops.convert_to_tensor(
transforms, name="transforms", dtype=dtypes.float32)
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif len(transform_or_transforms.get_shape()) == 2:
transforms = transform_or_transforms
else:
raise TypeError("Transforms should have rank 1 or 2.")
# Invert transformations
transforms = flat_transforms_to_matrices(transforms=transforms)
inverse = linalg_ops.matrix_inverse(transforms)
transforms = matrices_to_flat_transforms(inverse)
output = gen_image_ops.image_projective_transform_v3(
images=grad,
transforms=transforms,
output_shape=array_ops.shape(image_or_images)[1:3],
interpolation=interpolation,
fill_mode=fill_mode,
fill_value=0.0)
return [output, None, None, None]

View File

@ -118,7 +118,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], " argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "adapt" name: "adapt"

View File

@ -118,7 +118,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], " argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "adapt" name: "adapt"

View File

@ -118,7 +118,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'reflect\', \'bilinear\', \'None\', \'None\'], " argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "adapt" name: "adapt"

View File

@ -1880,6 +1880,10 @@ tf_module {
name: "ImageProjectiveTransformV2" name: "ImageProjectiveTransformV2"
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], " argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
} }
member_method {
name: "ImageProjectiveTransformV3"
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'fill_value\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
}
member_method { member_method {
name: "ImageSummary" name: "ImageSummary"
argspec: "args=[\'tag\', \'tensor\', \'max_images\', \'bad_color\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'dtype: DT_UINT8\\ntensor_shape {\\n dim {\\n size: 4\\n }\\n}\\nint_val: 255\\nint_val: 0\\nint_val: 0\\nint_val: 255\\n\', \'None\'], " argspec: "args=[\'tag\', \'tensor\', \'max_images\', \'bad_color\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'dtype: DT_UINT8\\ntensor_shape {\\n dim {\\n size: 4\\n }\\n}\\nint_val: 255\\nint_val: 0\\nint_val: 0\\nint_val: 255\\n\', \'None\'], "

View File

@ -118,7 +118,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], " argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "adapt" name: "adapt"

View File

@ -118,7 +118,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], " argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "adapt" name: "adapt"

View File

@ -118,7 +118,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'reflect\', \'bilinear\', \'None\', \'None\'], " argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "adapt" name: "adapt"

View File

@ -1880,6 +1880,10 @@ tf_module {
name: "ImageProjectiveTransformV2" name: "ImageProjectiveTransformV2"
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], " argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
} }
member_method {
name: "ImageProjectiveTransformV3"
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'fill_value\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
}
member_method { member_method {
name: "ImageSummary" name: "ImageSummary"
argspec: "args=[\'tag\', \'tensor\', \'max_images\', \'bad_color\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'dtype: DT_UINT8\\ntensor_shape {\\n dim {\\n size: 4\\n }\\n}\\nint_val: 255\\nint_val: 0\\nint_val: 0\\nint_val: 255\\n\', \'None\'], " argspec: "args=[\'tag\', \'tensor\', \'max_images\', \'bad_color\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'dtype: DT_UINT8\\ntensor_shape {\\n dim {\\n size: 4\\n }\\n}\\nint_val: 255\\nint_val: 0\\nint_val: 0\\nint_val: 255\\n\', \'None\'], "