diff --git a/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt index 73d548b226d..a9d5b981576 100644 --- a/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt @@ -38,6 +38,12 @@ END name: "interpolation" description: < -class ImageProjectiveTransform : public OpKernel { +class ImageProjectiveTransformV2 : public OpKernel { private: Interpolation interpolation_; + Mode fill_mode_; public: - explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) { + explicit ImageProjectiveTransformV2(OpKernelConstruction* ctx) + : OpKernel(ctx) { string interpolation_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str)); if (interpolation_str == "NEAREST") { - interpolation_ = INTERPOLATION_NEAREST; + interpolation_ = Interpolation::NEAREST; } else if (interpolation_str == "BILINEAR") { - interpolation_ = INTERPOLATION_BILINEAR; + interpolation_ = Interpolation::BILINEAR; } else { LOG(ERROR) << "Invalid interpolation " << interpolation_str << ". Supported types: NEAREST, BILINEAR"; } + string mode_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("fill_mode", &mode_str)); + if (mode_str == "REFLECT") { + fill_mode_ = Mode::REFLECT; + } else if (mode_str == "WRAP") { + fill_mode_ = Mode::WRAP; + } else if (mode_str == "CONSTANT") { + fill_mode_ = Mode::CONSTANT; + } else { + LOG(ERROR) << "Invalid mode " << mode_str + << ". Supported types: REFLECT, WRAP, CONSTANT"; + } } void Compute(OpKernelContext* ctx) override { @@ -78,8 +90,7 @@ class ImageProjectiveTransform : public OpKernel { (TensorShapeUtils::IsMatrix(transform_t.shape()) && (transform_t.dim_size(0) == images_t.dim_size(0) || transform_t.dim_size(0) == 1) && - transform_t.dim_size(1) == - ProjectiveGenerator::kNumParameters), + transform_t.dim_size(1) == 8), errors::InvalidArgument( "Input transform should be num_images x 8 or 1 x 8")); @@ -116,15 +127,15 @@ class ImageProjectiveTransform : public OpKernel { auto transform = transform_t.matrix(); (FillProjectiveTransform(interpolation_))( - ctx->eigen_device(), &output, images, transform); + ctx->eigen_device(), &output, images, transform, fill_mode_); } }; -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("dtype"), \ - ImageProjectiveTransform) +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype"), \ + ImageProjectiveTransformV2) TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); @@ -138,33 +149,48 @@ TF_CALL_double(REGISTER); #if GOOGLE_CUDA typedef Eigen::GpuDevice GPUDevice; +typedef generator::Mode Mode; namespace functor { // NOTE(ringwalt): We get an undefined symbol error if we don't explicitly // instantiate the operator() in GCC'd code. -#define DECLARE_FUNCTOR(TYPE) \ +#define DECLARE_PROJECT_FUNCTOR(TYPE) \ template <> \ void FillProjectiveTransform::operator()( \ const GPUDevice& device, OutputType* output, const InputType& images, \ - const TransformsType& transform) const; \ + const TransformsType& transform, const Mode fill_mode) const; \ extern template struct FillProjectiveTransform -TF_CALL_uint8(DECLARE_FUNCTOR); -TF_CALL_int32(DECLARE_FUNCTOR); -TF_CALL_int64(DECLARE_FUNCTOR); -TF_CALL_half(DECLARE_FUNCTOR); -TF_CALL_float(DECLARE_FUNCTOR); -TF_CALL_double(DECLARE_FUNCTOR); +TF_CALL_uint8(DECLARE_PROJECT_FUNCTOR); +TF_CALL_int32(DECLARE_PROJECT_FUNCTOR); +TF_CALL_int64(DECLARE_PROJECT_FUNCTOR); +TF_CALL_half(DECLARE_PROJECT_FUNCTOR); +TF_CALL_float(DECLARE_PROJECT_FUNCTOR); +TF_CALL_double(DECLARE_PROJECT_FUNCTOR); } // end namespace functor -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("dtype") \ - .HostMemory("output_shape"), \ - ImageProjectiveTransform) +namespace generator { + +#define DECLARE_MAP_FUNCTOR(Mode) \ + template <> \ + float MapCoordinate::operator()(const float out_coord, \ + const DenseIndex len); \ + extern template struct MapCoordinate + +DECLARE_MAP_FUNCTOR(Mode::REFLECT); +DECLARE_MAP_FUNCTOR(Mode::WRAP); +DECLARE_MAP_FUNCTOR(Mode::CONSTANT); + +} // end namespace generator + +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("dtype") \ + .HostMemory("output_shape"), \ + ImageProjectiveTransformV2) TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); diff --git a/tensorflow/core/kernels/image_ops.h b/tensorflow/core/kernels/image_ops.h index 4e375a67184..300c65921bd 100644 --- a/tensorflow/core/kernels/image_ops.h +++ b/tensorflow/core/kernels/image_ops.h @@ -29,12 +29,67 @@ namespace tensorflow { namespace generator { -enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR }; +enum Interpolation { NEAREST, BILINEAR }; +enum Mode { REFLECT, WRAP, CONSTANT }; using Eigen::array; using Eigen::DenseIndex; -template +template +struct MapCoordinate { + float operator()(const float out_coord, const DenseIndex len); +}; + +template +struct MapCoordinate { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord, + const DenseIndex len) { + float in_coord = out_coord; + // Reflect [abcd] to [dcba|abcd|dcba], periodically from [0, 2 * len) + // over [abcddcba] + const DenseIndex boundary = 2 * len; + // Shift coordinate to (-boundary, boundary) + in_coord -= boundary * static_cast(in_coord / boundary); + // Convert negative coordinates from [-boundary, 0) to [0, boundary) + if (in_coord < 0) { + in_coord += boundary; + } + // Coordinate in_coord between [len, boundary) should reverse reflect + // to coordinate to (bounary - 1 - in_coord) between [0, len) + if (in_coord > len - 1) { + in_coord = boundary - 1 - in_coord; + } + return in_coord; + } +}; + +template +struct MapCoordinate { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord, + const DenseIndex len) { + float in_coord = out_coord; + // Wrap [abcd] to [abcd|abcd|abcd], periodically from [0, len) + // over [abcd] + const DenseIndex boundary = len; + // Shift coordinate to (-boundary, boundary) + in_coord -= boundary * static_cast(in_coord / boundary); + // Shift negative coordinate from [-boundary, 0) to [0, boundary) + if (in_coord < 0) { + in_coord += boundary; + } + return in_coord; + } +}; + +template +struct MapCoordinate { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord, + const DenseIndex len) { + return out_coord; + } +}; + +template class ProjectiveGenerator { private: typename TTypes::ConstTensor input_; @@ -42,8 +97,6 @@ class ProjectiveGenerator { const Interpolation interpolation_; public: - static const int kNumParameters = 8; - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ProjectiveGenerator(typename TTypes::ConstTensor input, typename TTypes::ConstMatrix transforms, @@ -52,6 +105,7 @@ class ProjectiveGenerator { EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T operator()(const array& coords) const { + const T fill_value = T(0); const int64 output_y = coords[1]; const int64 output_x = coords[2]; const float* transform = @@ -62,7 +116,7 @@ class ProjectiveGenerator { if (projection == 0) { // Return the fill value (0) for infinite coordinates, // which are outside the input image - return T(0); + return fill_value; } const float input_x = (transform[0] * output_x + transform[1] * output_y + transform[2]) / @@ -71,22 +125,24 @@ class ProjectiveGenerator { (transform[3] * output_x + transform[4] * output_y + transform[5]) / projection; - const T fill_value = T(0); + // Map out-of-boundary input coordinates to in-boundary based on fill_mode. + auto map_functor = MapCoordinate(); + const float x = map_functor(input_x, input_.dimension(2)); + const float y = map_functor(input_y, input_.dimension(1)); + + const DenseIndex batch = coords[0]; + const DenseIndex channels = coords[3]; switch (interpolation_) { - case INTERPOLATION_NEAREST: - // Switch the order of x and y again for indexing into the image. - return nearest_interpolation(coords[0], input_y, input_x, coords[3], - fill_value); - case INTERPOLATION_BILINEAR: - return bilinear_interpolation(coords[0], input_y, input_x, coords[3], - fill_value); + case NEAREST: + return nearest_interpolation(batch, y, x, channels, fill_value); + case BILINEAR: + return bilinear_interpolation(batch, y, x, channels, fill_value); } // Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST // or INTERPOLATION_BILINEAR. - return T(0); + return fill_value; } - private: EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T nearest_interpolation(const DenseIndex batch, const float y, const float x, const DenseIndex channel, const T fill_value) const { @@ -138,12 +194,10 @@ class ProjectiveGenerator { } // end namespace generator -// NOTE(ringwalt): We MUST wrap the generate() call in a functor and explicitly -// instantiate the functor in image_ops_gpu.cu.cc. Otherwise, we will be missing -// some Eigen device code. namespace functor { using generator::Interpolation; +using generator::Mode; using generator::ProjectiveGenerator; template @@ -151,17 +205,32 @@ struct FillProjectiveTransform { typedef typename TTypes::Tensor OutputType; typedef typename TTypes::ConstTensor InputType; typedef typename TTypes::ConstTensor TransformsType; - const Interpolation interpolation_; + const Interpolation interpolation; - FillProjectiveTransform(Interpolation interpolation) - : interpolation_(interpolation) {} + explicit FillProjectiveTransform(Interpolation interpolation) + : interpolation(interpolation) {} EIGEN_ALWAYS_INLINE void operator()(const Device& device, OutputType* output, - const InputType& images, - const TransformsType& transform) const { - output->device(device) = output->generate( - ProjectiveGenerator(images, transform, interpolation_)); + const InputType& images, const TransformsType& transform, + const Mode fill_mode) const { + switch (fill_mode) { + case Mode::REFLECT: + output->device(device) = + output->generate(ProjectiveGenerator( + images, transform, interpolation)); + break; + case Mode::WRAP: + output->device(device) = + output->generate(ProjectiveGenerator( + images, transform, interpolation)); + break; + case Mode::CONSTANT: + output->device(device) = + output->generate(ProjectiveGenerator( + images, transform, interpolation)); + break; + } } }; diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index c2d4eedb1c1..418f1e20e37 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -1028,7 +1028,6 @@ REGISTER_OP("GenerateBoundingBoxProposals") return Status::OK(); }); -// TODO(ringwalt): Add a "fill_mode" attr with "constant", "mirror", etc. // TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0). // V2 op supports output_shape. V1 op is in contrib. REGISTER_OP("ImageProjectiveTransformV2") @@ -1037,6 +1036,7 @@ REGISTER_OP("ImageProjectiveTransformV2") .Input("output_shape: int32") .Attr("dtype: {uint8, int32, int64, float16, float32, float64}") .Attr("interpolation: string") + .Attr("fill_mode: string = 'CONSTANT'") .Output("transformed_images: dtype") .SetShapeFn([](InferenceContext* c) { ShapeHandle input; diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py index 43d60cc6a9c..8a886ec2778 100644 --- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py +++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -418,9 +419,15 @@ class RandomTranslation(Layer): When represented as a single float, this value is used for both the upper and lower bound. fill_mode: Points outside the boundaries of the input are filled according - to the given mode (one of `{'nearest', 'bilinear'}`). - fill_value: Value used for points outside the boundaries of the input if - `mode='constant'`. + to the given mode (one of `{'constant', 'reflect', 'wrap'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + interpolation: Interpolation mode. Supported values: "nearest", "bilinear". seed: Integer. Used to create a random seed. name: A string, the name of the layer. @@ -440,8 +447,8 @@ class RandomTranslation(Layer): def __init__(self, height_factor, width_factor, - fill_mode='nearest', - fill_value=0., + fill_mode='reflect', + interpolation='bilinear', seed=None, name=None, **kwargs): @@ -471,11 +478,16 @@ class RandomTranslation(Layer): raise ValueError('`width_factor` must have values between [-1, 1], ' 'got {}'.format(width_factor)) - if fill_mode not in {'nearest', 'bilinear'}: + if fill_mode not in {'reflect', 'wrap', 'constant'}: raise NotImplementedError( - '`fill_mode` {} is not supported yet.'.format(fill_mode)) + 'Unknown `fill_mode` {}. Only `reflect`, `wrap` and ' + '`constant` are supported.'.format(fill_mode)) + if interpolation not in {'nearest', 'bilinear'}: + raise NotImplementedError( + 'Unknown `interpolation` {}. Only `nearest` and ' + '`bilinear` are supported.'.format(interpolation)) self.fill_mode = fill_mode - self.fill_value = fill_value + self.interpolation = interpolation self.seed = seed self._rng = make_generator(self.seed) self.input_spec = InputSpec(ndim=4) @@ -508,7 +520,8 @@ class RandomTranslation(Layer): return transform( inputs, get_translation_matrix(translations), - interpolation=self.fill_mode) + interpolation=self.interpolation, + fill_mode=self.fill_mode) output = tf_utils.smart_cond(training, random_translated_inputs, lambda: inputs) @@ -523,7 +536,7 @@ class RandomTranslation(Layer): 'height_factor': self.height_factor, 'width_factor': self.width_factor, 'fill_mode': self.fill_mode, - 'fill_value': self.fill_value, + 'interpolation': self.interpolation, 'seed': self.seed, } base_config = super(RandomTranslation, self).get_config() @@ -565,7 +578,8 @@ def get_translation_matrix(translations, name=None): def transform(images, transforms, - interpolation='nearest', + fill_mode='reflect', + interpolation='bilinear', output_shape=None, name=None): """Applies the given transform(s) to the image(s). @@ -582,11 +596,33 @@ def transform(images, `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to the transform mapping input points to output points. Note that gradients are not backpropagated into transformation parameters. - interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap'}`). + interpolation: Interpolation mode. Supported values: "nearest", "bilinear". output_shape: Output dimesion after the transform, [height, width]. If None, output is the same size as input image. name: The name of the op. + ## Fill mode. + Behavior for each valid value is as follows: + + reflect (d c b a | a b c d | d c b a) + The input is extended by reflecting about the edge of the last pixel. + + constant (k k k k | a b c d | k k k k) + The input is extended by filling all values beyond the edge with the same + constant value k = 0. + + wrap (a b c d | a b c d | a b c d) + The input is extended by wrapping around to the opposite edge. + + Input shape: + 4D tensor with shape: `(samples, height, width, channels)`, + data_format='channels_last'. + Output shape: + 4D tensor with shape: `(samples, height, width, channels)`, + data_format='channels_last'. + Returns: Image(s) with the same type and shape as `images`, with the given transform(s) applied. Transformed coordinates outside of the input image @@ -612,6 +648,13 @@ def transform(images, 'new_height, new_width, instead got ' '{}'.format(output_shape)) + if compat.forward_compatible(2020, 3, 25): + return image_ops.image_projective_transform_v2( + images, + output_shape=output_shape, + transforms=transforms, + fill_mode=fill_mode.upper(), + interpolation=interpolation.upper()) return image_ops.image_projective_transform_v2( images, output_shape=output_shape, @@ -680,11 +723,24 @@ class RandomRotation(Layer): 2 representing lower and upper bound for rotating clockwise and counter-clockwise. When represented as a single float, lower = upper. fill_mode: Points outside the boundaries of the input are filled according - to the given mode (one of `{'constant', 'nearest', 'bilinear', 'reflect', - 'wrap'}`). + to the given mode (one of `{'constant', 'reflect', 'wrap'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + interpolation: Interpolation mode. Supported values: "nearest", "bilinear". seed: Integer. Used to create a random seed. name: A string, the name of the layer. + Input shape: + 4D tensor with shape: `(samples, height, width, channels)`, + data_format='channels_last'. + Output shape: + 4D tensor with shape: `(samples, height, width, channels)`, + data_format='channels_last'. + Raise: ValueError: if lower bound is not between [0, 1], or upper bound is negative. @@ -692,7 +748,8 @@ class RandomRotation(Layer): def __init__(self, factor, - fill_mode='nearest', + fill_mode='reflect', + interpolation='bilinear', seed=None, name=None, **kwargs): @@ -705,10 +762,16 @@ class RandomRotation(Layer): if self.lower < 0. or self.upper < 0.: raise ValueError('Factor cannot have negative values, ' 'got {}'.format(factor)) - if fill_mode not in {'nearest', 'bilinear'}: + if fill_mode not in {'reflect', 'wrap', 'constant'}: raise NotImplementedError( - '`fill_mode` {} is not supported yet.'.format(fill_mode)) + 'Unknown `fill_mode` {}. Only `reflect`, `wrap` and ' + '`constant` are supported.'.format(fill_mode)) + if interpolation not in {'nearest', 'bilinear'}: + raise NotImplementedError( + 'Unknown `interpolation` {}. Only `nearest` and ' + '`bilinear` are supported.'.format(interpolation)) self.fill_mode = fill_mode + self.interpolation = interpolation self.seed = seed self._rng = make_generator(self.seed) self.input_spec = InputSpec(ndim=4) @@ -732,7 +795,8 @@ class RandomRotation(Layer): return transform( inputs, get_rotation_matrix(angles, img_hd, img_wd), - interpolation=self.fill_mode) + fill_mode=self.fill_mode, + interpolation=self.interpolation) output = tf_utils.smart_cond(training, random_rotated_inputs, lambda: inputs) @@ -746,6 +810,7 @@ class RandomRotation(Layer): config = { 'factor': self.factor, 'fill_mode': self.fill_mode, + 'interpolation': self.interpolation, 'seed': self.seed, } base_config = super(RandomRotation, self).get_config() @@ -768,9 +833,14 @@ class RandomZoom(Layer): upper and lower bound. For instance, `width_factor=(0.2, 0.3)` result in an output zoom varying in the range `[original * 20%, original * 30%]`. fill_mode: Points outside the boundaries of the input are filled according - to the given mode (one of `{'nearest', 'bilinear'}`). - fill_value: Value used for points outside the boundaries of the input if - `mode='constant'`. + to the given mode (one of `{'constant', 'reflect', 'wrap'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + interpolation: Interpolation mode. Supported values: "nearest", "bilinear". seed: Integer. Used to create a random seed. name: A string, the name of the layer. @@ -790,8 +860,8 @@ class RandomZoom(Layer): def __init__(self, height_factor, width_factor, - fill_mode='nearest', - fill_value=0., + fill_mode='reflect', + interpolation='bilinear', seed=None, name=None, **kwargs): @@ -821,11 +891,16 @@ class RandomZoom(Layer): raise ValueError('`width_factor` cannot have lower bound larger than ' 'upper bound, got {}.'.format(width_factor)) - if fill_mode not in {'nearest', 'bilinear'}: + if fill_mode not in {'reflect', 'wrap', 'constant'}: raise NotImplementedError( - '`fill_mode` {} is not supported yet.'.format(fill_mode)) + 'Unknown `fill_mode` {}. Only `reflect`, `wrap` and ' + '`constant` are supported.'.format(fill_mode)) + if interpolation not in {'nearest', 'bilinear'}: + raise NotImplementedError( + 'Unknown `interpolation` {}. Only `nearest` and ' + '`bilinear` are supported.'.format(interpolation)) self.fill_mode = fill_mode - self.fill_value = fill_value + self.interpolation = interpolation self.seed = seed self._rng = make_generator(self.seed) self.input_spec = InputSpec(ndim=4) @@ -857,7 +932,8 @@ class RandomZoom(Layer): dtype=inputs.dtype) return transform( inputs, get_zoom_matrix(zooms, img_hd, img_wd), - interpolation=self.fill_mode) + fill_mode=self.fill_mode, + interpolation=self.interpolation) output = tf_utils.smart_cond(training, random_zoomed_inputs, lambda: inputs) @@ -872,7 +948,7 @@ class RandomZoom(Layer): 'height_factor': self.height_factor, 'width_factor': self.width_factor, 'fill_mode': self.fill_mode, - 'fill_value': self.fill_value, + 'interpolation': self.interpolation, 'seed': self.seed, } base_config = super(RandomZoom, self).get_config() diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py index 3297a9d4849..ff5c63efd5a 100644 --- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py +++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np +from tensorflow.python.compat import compat from tensorflow.python.framework import errors from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import keras_parameterized @@ -509,6 +510,250 @@ class RandomTranslationTest(keras_parameterized.TestCase): self.assertEqual(layer_1.name, layer.name) +@keras_parameterized.run_all_keras_modes(always_skip_v1=True) +class RandomTransformTest(keras_parameterized.TestCase): + + def _run_random_transform_with_mock(self, + transform_matrix, + expected_output, + mode, + interpolation='bilinear'): + inp = np.arange(15).reshape((1, 5, 3, 1)).astype(np.float32) + with self.cached_session(use_gpu=True): + output = image_preprocessing.transform( + inp, transform_matrix, fill_mode=mode, interpolation=interpolation) + self.assertAllClose(expected_output, output) + + def test_random_translation_reflect(self): + # reflected output is (dcba|abcd|dcba) + + if compat.forward_compatible(2020, 3, 25): + # Test down shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[0., 1., 2.], + [0., 1., 2.], + [3., 4., 5.], + [6., 7., 8], + [9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'reflect') + + # Test up shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[3., 4., 5.], + [6., 7., 8], + [9., 10., 11.], + [12., 13., 14.], + [12., 13., 14.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'reflect') + + # Test left shift by 1. + # reflected output is (dcba|abcd|dcba) + # pyformat: disable + expected_output = np.asarray( + [[1., 2., 2.], + [4., 5., 5.], + [7., 8., 8.], + [10., 11., 11.], + [13., 14., 14.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'reflect') + + # Test right shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[0., 0., 1.], + [3., 3., 4], + [6., 6., 7.], + [9., 9., 10.], + [12., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'reflect') + + def test_random_translation_wrap(self): + # warpped output is (abcd|abcd|abcd) + + if compat.forward_compatible(2020, 3, 25): + # Test down shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[12., 13., 14.], + [0., 1., 2.], + [3., 4., 5.], + [6., 7., 8], + [9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'wrap') + + # Test up shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[3., 4., 5.], + [6., 7., 8], + [9., 10., 11.], + [12., 13., 14.], + [0., 1., 2.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'wrap') + + # Test left shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[1., 2., 0.], + [4., 5., 3.], + [7., 8., 6.], + [10., 11., 9.], + [13., 14., 12.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'wrap') + + # Test right shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[2., 0., 1.], + [5., 3., 4], + [8., 6., 7.], + [11., 9., 10.], + [14., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'wrap') + + def test_random_translation_constant(self): + # constant output is (0000|abcd|0000) + + if compat.forward_compatible(2020, 3, 25): + # Test down shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[0., 0., 0.], + [0., 1., 2.], + [3., 4., 5.], + [6., 7., 8], + [9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'constant') + + # Test up shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[3., 4., 5.], + [6., 7., 8], + [9., 10., 11.], + [12., 13., 14.], + [0., 0., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'constant') + + # Test left shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[1., 2., 0.], + [4., 5., 0.], + [7., 8., 0.], + [10., 11., 0.], + [13., 14., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'constant') + + # Test right shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[0., 0., 1.], + [0., 3., 4], + [0., 6., 7.], + [0., 9., 10.], + [0., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]]) + self._run_random_transform_with_mock(transform_matrix, expected_output, + 'constant') + + def test_random_translation_nearest_interpolation(self): + # nearest output is (aaaa|abcd|dddd) + + if compat.forward_compatible(2020, 3, 25): + # Test down shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[0., 0., 0.], + [0., 1., 2.], + [3., 4., 5.], + [6., 7., 8], + [9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]]) + self._run_random_transform_with_mock( + transform_matrix, expected_output, + mode='constant', interpolation='nearest') + + # Test up shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[3., 4., 5.], + [6., 7., 8], + [9., 10., 11.], + [12., 13., 14.], + [0., 0., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]]) + self._run_random_transform_with_mock( + transform_matrix, expected_output, + mode='constant', interpolation='nearest') + + # Test left shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[1., 2., 0.], + [4., 5., 0.], + [7., 8., 0.], + [10., 11., 0.], + [13., 14., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]]) + self._run_random_transform_with_mock( + transform_matrix, expected_output, + mode='constant', interpolation='nearest') + + # Test right shift by 1. + # pyformat: disable + expected_output = np.asarray( + [[0., 0., 1.], + [0., 3., 4], + [0., 6., 7.], + [0., 9., 10.], + [0., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32) + # pyformat: enable + transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]]) + self._run_random_transform_with_mock( + transform_matrix, expected_output, + mode='constant', interpolation='nearest') + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) class RandomRotationTest(keras_parameterized.TestCase): diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index 8d4d35ffa95..3c8d4989a5f 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -245,6 +245,7 @@ def _image_projective_transform_grad(op, grad): images = op.inputs[0] transforms = op.inputs[1] interpolation = op.get_attr("interpolation") + fill_mode = op.get_attr("fill_mode") image_or_images = ops.convert_to_tensor(images, name="images") transform_or_transforms = ops.convert_to_tensor( @@ -267,5 +268,6 @@ def _image_projective_transform_grad(op, grad): images=grad, transforms=transforms, output_shape=array_ops.shape(image_or_images)[1:3], - interpolation=interpolation) + interpolation=interpolation, + fill_mode=fill_mode) return [output, None, None] diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt index f80dc852d65..8aeee741de8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt index f1697dfb854..1b0daace7bf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt index fa22e72f7e0..3fb7e6856c8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index d62a863d710..45c92b94119 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1778,7 +1778,7 @@ tf_module { } member_method { name: "ImageProjectiveTransformV2" - argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], " } member_method { name: "ImageSummary" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt index f80dc852d65..8aeee741de8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-rotation.pbtxt @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt index f1697dfb854..1b0daace7bf 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-translation.pbtxt @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt index fa22e72f7e0..3fb7e6856c8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-zoom.pbtxt @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index d62a863d710..45c92b94119 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1778,7 +1778,7 @@ tf_module { } member_method { name: "ImageProjectiveTransformV2" - argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], " } member_method { name: "ImageSummary"