From 26992815bf7ee136c49f658596a9fbde7881d425 Mon Sep 17 00:00:00 2001
From: Zhenyu Tan <tanzheny@google.com>
Date: Wed, 29 Jan 2020 18:55:24 -0800
Subject: [PATCH] Image Translation Keras layer for preprocessing.

PiperOrigin-RevId: 292262892
Change-Id: Iff5f81b22ed5d39cc4f5b7eca9fa55ca5936b9d8
---
 .../api_def_ImageProjectiveTransformV2.pbtxt  |  51 +++++
 tensorflow/core/kernels/BUILD                 |   7 +
 tensorflow/core/kernels/image_ops.cc          | 180 +++++++++++++++
 tensorflow/core/kernels/image_ops.h           | 172 ++++++++++++++
 tensorflow/core/kernels/image_ops_gpu.cu.cc   |  43 ++++
 tensorflow/core/ops/image_ops.cc              |  18 ++
 .../preprocessing/image_preprocessing.py      | 216 +++++++++++++++++-
 .../preprocessing/image_preprocessing_test.py |  54 +++++
 tensorflow/python/ops/image_ops.py            | 109 ++++++++-
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |   4 +
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |   4 +
 11 files changed, 856 insertions(+), 2 deletions(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt
 create mode 100644 tensorflow/core/kernels/image_ops.cc
 create mode 100644 tensorflow/core/kernels/image_ops.h
 create mode 100644 tensorflow/core/kernels/image_ops_gpu.cu.cc

diff --git a/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt
new file mode 100644
index 00000000000..73d548b226d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ImageProjectiveTransformV2.pbtxt
@@ -0,0 +1,51 @@
+op {
+  graph_op_name: "ImageProjectiveTransformV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "images"
+    description: <<END
+4-D with shape `[batch, height, width, channels]`.
+END
+  }
+  in_arg {
+    name: "transforms"
+    description: <<END
+2-D Tensor, `[batch, 8]` or `[1, 8]` matrix, where each row corresponds to a 3 x 3
+projective transformation matrix, with the last entry assumed to be 1. If there
+is one row, the same transformation will be applied to all images.
+END
+  }
+  in_arg {
+    name: "output_shape"
+    description: <<END
+1-D Tensor [new_height, new_width].
+END
+  }
+  out_arg {
+    name: "transformed_images"
+    description: <<END
+4-D with shape
+`[batch, new_height, new_width, channels]`.
+END
+  }
+  attr {
+    name: "dtype"
+    description: <<END
+Input dtype.
+END
+  }
+  attr {
+    name: "interpolation"
+    description: <<END
+Interpolation method, "NEAREST" or "BILINEAR".
+END
+  }
+  summary: "Applies the given transform to each of the images."
+  description: <<END
+If one row of `transforms` is `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps
+the *output* point `(x, y)` to a transformed *input* point
+`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where
+`k = c0 x + c1 y + 1`. If the transformed point lays outside of the input
+image, the output pixel is set to 0.
+END
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f8c61084dee..b70c1e5ae9e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2924,6 +2924,7 @@ cc_library(
         ":encode_png_op",
         ":extract_jpeg_shape_op",
         ":generate_box_proposals_op",
+        ":image_ops",
         ":non_max_suppression_op",
         ":random_crop_op",
         ":resize_area_op",
@@ -3084,6 +3085,12 @@ tf_kernel_library(
     deps = IMAGE_DEPS,
 )
 
+tf_kernel_library(
+    name = "image_ops",
+    prefix = "image_ops",
+    deps = IMAGE_DEPS,
+)
+
 tf_kernel_library(
     name = "encode_wav_op",
     prefix = "encode_wav_op",
diff --git a/tensorflow/core/kernels/image_ops.cc b/tensorflow/core/kernels/image_ops.cc
new file mode 100644
index 00000000000..2e81cdaad72
--- /dev/null
+++ b/tensorflow/core/kernels/image_ops.cc
@@ -0,0 +1,180 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/image_ops.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+// Explicit instantiation of the CPU functor.
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template struct FillProjectiveTransform<CPUDevice, uint8>;
+template struct FillProjectiveTransform<CPUDevice, int32>;
+template struct FillProjectiveTransform<CPUDevice, int64>;
+template struct FillProjectiveTransform<CPUDevice, Eigen::half>;
+template struct FillProjectiveTransform<CPUDevice, float>;
+template struct FillProjectiveTransform<CPUDevice, double>;
+
+}  // end namespace functor
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+using functor::FillProjectiveTransform;
+using generator::Interpolation;
+using generator::INTERPOLATION_BILINEAR;
+using generator::INTERPOLATION_NEAREST;
+using generator::ProjectiveGenerator;
+
+template <typename Device, typename T>
+class ImageProjectiveTransform : public OpKernel {
+ private:
+  Interpolation interpolation_;
+
+ public:
+  explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    string interpolation_str;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
+    if (interpolation_str == "NEAREST") {
+      interpolation_ = INTERPOLATION_NEAREST;
+    } else if (interpolation_str == "BILINEAR") {
+      interpolation_ = INTERPOLATION_BILINEAR;
+    } else {
+      LOG(ERROR) << "Invalid interpolation " << interpolation_str
+                 << ". Supported types: NEAREST, BILINEAR";
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& images_t = ctx->input(0);
+    const Tensor& transform_t = ctx->input(1);
+    OP_REQUIRES(ctx, images_t.shape().dims() == 4,
+                errors::InvalidArgument("Input images must have rank 4"));
+    OP_REQUIRES(ctx,
+                (TensorShapeUtils::IsMatrix(transform_t.shape()) &&
+                 (transform_t.dim_size(0) == images_t.dim_size(0) ||
+                  transform_t.dim_size(0) == 1) &&
+                 transform_t.dim_size(1) ==
+                     ProjectiveGenerator<Device, T>::kNumParameters),
+                errors::InvalidArgument(
+                    "Input transform should be num_images x 8 or 1 x 8"));
+
+    int32 out_height, out_width;
+    // Kernel is shared by legacy "ImageProjectiveTransform" op with 2 args.
+    if (ctx->num_inputs() >= 3) {
+      const Tensor& shape_t = ctx->input(2);
+      OP_REQUIRES(ctx, shape_t.dims() == 1,
+                  errors::InvalidArgument("output shape must be 1-dimensional",
+                                          shape_t.shape().DebugString()));
+      OP_REQUIRES(ctx, shape_t.NumElements() == 2,
+                  errors::InvalidArgument("output shape must have two elements",
+                                          shape_t.shape().DebugString()));
+      auto shape_vec = shape_t.vec<int32>();
+      out_height = shape_vec(0);
+      out_width = shape_vec(1);
+      OP_REQUIRES(
+          ctx, out_height > 0 && out_width > 0,
+          errors::InvalidArgument("output dimensions must be positive"));
+    } else {
+      // Shape is N (batch size), H (height), W (width), C (channels).
+      out_height = images_t.shape().dim_size(1);
+      out_width = images_t.shape().dim_size(2);
+    }
+
+    Tensor* output_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(
+                            0,
+                            TensorShape({images_t.dim_size(0), out_height,
+                                         out_width, images_t.dim_size(3)}),
+                            &output_t));
+    auto output = output_t->tensor<T, 4>();
+    auto images = images_t.tensor<T, 4>();
+    auto transform = transform_t.matrix<float>();
+
+    (FillProjectiveTransform<Device, T>(interpolation_))(
+        ctx->eigen_device<Device>(), &output, images, transform);
+  }
+};
+
+#define REGISTER(TYPE)                                                \
+  REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2")          \
+                              .Device(DEVICE_CPU)                     \
+                              .TypeConstraint<TYPE>("dtype"),         \
+                          ImageProjectiveTransform<CPUDevice, TYPE>)
+
+TF_CALL_uint8(REGISTER);
+TF_CALL_int32(REGISTER);
+TF_CALL_int64(REGISTER);
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+
+#undef REGISTER
+
+#if GOOGLE_CUDA
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+// NOTE(ringwalt): We get an undefined symbol error if we don't explicitly
+// instantiate the operator() in GCC'd code.
+#define DECLARE_FUNCTOR(TYPE)                                               \
+  template <>                                                               \
+  void FillProjectiveTransform<GPUDevice, TYPE>::operator()(                \
+      const GPUDevice& device, OutputType* output, const InputType& images, \
+      const TransformsType& transform) const;                               \
+  extern template struct FillProjectiveTransform<GPUDevice, TYPE>
+
+TF_CALL_uint8(DECLARE_FUNCTOR);
+TF_CALL_int32(DECLARE_FUNCTOR);
+TF_CALL_int64(DECLARE_FUNCTOR);
+TF_CALL_half(DECLARE_FUNCTOR);
+TF_CALL_float(DECLARE_FUNCTOR);
+TF_CALL_double(DECLARE_FUNCTOR);
+
+}  // end namespace functor
+
+#define REGISTER(TYPE)                                                \
+  REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2")          \
+                              .Device(DEVICE_GPU)                     \
+                              .TypeConstraint<TYPE>("dtype")          \
+                              .HostMemory("output_shape"),            \
+                          ImageProjectiveTransform<GPUDevice, TYPE>)
+
+TF_CALL_uint8(REGISTER);
+TF_CALL_int32(REGISTER);
+TF_CALL_int64(REGISTER);
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+
+#undef REGISTER
+
+#endif  // GOOGLE_CUDA
+
+}  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/image_ops.h b/tensorflow/core/kernels/image_ops.h
new file mode 100644
index 00000000000..4e375a67184
--- /dev/null
+++ b/tensorflow/core/kernels/image_ops.h
@@ -0,0 +1,172 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_IMAGE_OPS_H_
+
+// See docs in ../ops/image_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace generator {
+
+enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
+
+using Eigen::array;
+using Eigen::DenseIndex;
+
+template <typename Device, typename T>
+class ProjectiveGenerator {
+ private:
+  typename TTypes<T, 4>::ConstTensor input_;
+  typename TTypes<float>::ConstMatrix transforms_;
+  const Interpolation interpolation_;
+
+ public:
+  static const int kNumParameters = 8;
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+  ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
+                      typename TTypes<float>::ConstMatrix transforms,
+                      const Interpolation interpolation)
+      : input_(input), transforms_(transforms), interpolation_(interpolation) {}
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+  operator()(const array<DenseIndex, 4>& coords) const {
+    const int64 output_y = coords[1];
+    const int64 output_x = coords[2];
+    const float* transform =
+        transforms_.dimension(0) == 1
+            ? transforms_.data()
+            : &transforms_.data()[transforms_.dimension(1) * coords[0]];
+    float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
+    if (projection == 0) {
+      // Return the fill value (0) for infinite coordinates,
+      // which are outside the input image
+      return T(0);
+    }
+    const float input_x =
+        (transform[0] * output_x + transform[1] * output_y + transform[2]) /
+        projection;
+    const float input_y =
+        (transform[3] * output_x + transform[4] * output_y + transform[5]) /
+        projection;
+
+    const T fill_value = T(0);
+    switch (interpolation_) {
+      case INTERPOLATION_NEAREST:
+        // Switch the order of x and y again for indexing into the image.
+        return nearest_interpolation(coords[0], input_y, input_x, coords[3],
+                                     fill_value);
+      case INTERPOLATION_BILINEAR:
+        return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
+                                      fill_value);
+    }
+    // Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST
+    // or INTERPOLATION_BILINEAR.
+    return T(0);
+  }
+
+ private:
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+  nearest_interpolation(const DenseIndex batch, const float y, const float x,
+                        const DenseIndex channel, const T fill_value) const {
+    return read_with_fill_value(batch, DenseIndex(std::round(y)),
+                                DenseIndex(std::round(x)), channel, fill_value);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+  bilinear_interpolation(const DenseIndex batch, const float y, const float x,
+                         const DenseIndex channel, const T fill_value) const {
+    const float y_floor = std::floor(y);
+    const float x_floor = std::floor(x);
+    const float y_ceil = y_floor + 1;
+    const float x_ceil = x_floor + 1;
+    // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
+    //               + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
+    const float value_yfloor =
+        (x_ceil - x) * static_cast<float>(read_with_fill_value(
+                           batch, DenseIndex(y_floor), DenseIndex(x_floor),
+                           channel, fill_value)) +
+        (x - x_floor) * static_cast<float>(read_with_fill_value(
+                            batch, DenseIndex(y_floor), DenseIndex(x_ceil),
+                            channel, fill_value));
+    // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
+    //              + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
+    const float value_yceil =
+        (x_ceil - x) * static_cast<float>(read_with_fill_value(
+                           batch, DenseIndex(y_ceil), DenseIndex(x_floor),
+                           channel, fill_value)) +
+        (x - x_floor) * static_cast<float>(read_with_fill_value(
+                            batch, DenseIndex(y_ceil), DenseIndex(x_ceil),
+                            channel, fill_value));
+    // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
+    //         + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
+    return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
+      const DenseIndex batch, const DenseIndex y, const DenseIndex x,
+      const DenseIndex channel, const T fill_value) const {
+    // batch and channel must be correct, because they are passed unchanged from
+    // the input.
+    return (0 <= y && y < input_.dimension(1) && 0 <= x &&
+            x < input_.dimension(2))
+               ? input_(array<DenseIndex, 4>{batch, y, x, channel})
+               : fill_value;
+  }
+};
+
+}  // end namespace generator
+
+// NOTE(ringwalt): We MUST wrap the generate() call in a functor and explicitly
+// instantiate the functor in image_ops_gpu.cu.cc. Otherwise, we will be missing
+// some Eigen device code.
+namespace functor {
+
+using generator::Interpolation;
+using generator::ProjectiveGenerator;
+
+template <typename Device, typename T>
+struct FillProjectiveTransform {
+  typedef typename TTypes<T, 4>::Tensor OutputType;
+  typedef typename TTypes<T, 4>::ConstTensor InputType;
+  typedef typename TTypes<float, 2>::ConstTensor TransformsType;
+  const Interpolation interpolation_;
+
+  FillProjectiveTransform(Interpolation interpolation)
+      : interpolation_(interpolation) {}
+
+  EIGEN_ALWAYS_INLINE
+  void operator()(const Device& device, OutputType* output,
+                  const InputType& images,
+                  const TransformsType& transform) const {
+    output->device(device) = output->generate(
+        ProjectiveGenerator<Device, T>(images, transform, interpolation_));
+  }
+};
+
+}  // end namespace functor
+
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_IMAGE_OPS_H_
diff --git a/tensorflow/core/kernels/image_ops_gpu.cu.cc b/tensorflow/core/kernels/image_ops_gpu.cu.cc
new file mode 100644
index 00000000000..827fb493e4c
--- /dev/null
+++ b/tensorflow/core/kernels/image_ops_gpu.cu.cc
@@ -0,0 +1,43 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/image_ops.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+// Explicit instantiation of the GPU functor.
+typedef Eigen::GpuDevice GPUDevice;
+
+template class FillProjectiveTransform<GPUDevice, uint8>;
+template class FillProjectiveTransform<GPUDevice, int32>;
+template class FillProjectiveTransform<GPUDevice, int64>;
+template class FillProjectiveTransform<GPUDevice, Eigen::half>;
+template class FillProjectiveTransform<GPUDevice, float>;
+template class FillProjectiveTransform<GPUDevice, double>;
+
+}  // end namespace functor
+
+}  // end namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 57c032fbf37..a366d57c76f 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -1021,4 +1021,22 @@ REGISTER_OP("GenerateBoundingBoxProposals")
       c->set_output(1, prob_shape);
       return Status::OK();
     });
+
+// TODO(ringwalt): Add a "fill_mode" attr with "constant", "mirror", etc.
+// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
+// V2 op supports output_shape. V1 op is in contrib.
+REGISTER_OP("ImageProjectiveTransformV2")
+    .Input("images: dtype")
+    .Input("transforms: float32")
+    .Input("output_shape: int32")
+    .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
+    .Attr("interpolation: string")
+    .Output("transformed_images: dtype")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+      return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
+                                   c->Dim(input, 3));
+    });
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
index 4504505cd60..a6a1ccfcc0b 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.engine.input_spec import InputSpec
@@ -28,7 +30,7 @@ from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import image_ops_impl as image_ops
+from tensorflow.python.ops import image_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import stateful_random_ops
 from tensorflow.python.ops import stateless_random_ops
@@ -382,6 +384,218 @@ class RandomFlip(Layer):
     return dict(list(base_config.items()) + list(config.items()))
 
 
+class RandomTranslation(Layer):
+  """Randomly translate each image during training.
+
+  Arguments:
+    height_factor: a positive float represented as fraction of value, or a tuple
+      of size 2 representing lower and upper bound for shifting vertically. When
+      represented as a single float, this value is used for both the upper and
+      lower bound. For instance, `height_factor=(0.2, 0.3)` results in an output
+      height varying in the range `[original - 20%, original + 30%]`.
+      `height_factor=0.2` results in an output height varying in the range
+      `[original - 20%, original + 20%]`.
+    width_factor: a positive float represented as fraction of value, or a tuple
+      of size 2 representing lower and upper bound for shifting horizontally.
+      When represented as a single float, this value is used for both the upper
+      and lower bound.
+    fill_mode: Points outside the boundaries of the input are filled according
+      to the given mode (one of `{'nearest', 'bilinear'}`).
+    fill_value: Value used for points outside the boundaries of the input if
+      `mode='constant'`.
+    seed: Integer. Used to create a random seed.
+  Input shape:
+    4D tensor with shape: `(samples, height, width, channels)`,
+      data_format='channels_last'.
+  Output shape:
+    4D tensor with shape: `(samples, height, width, channels)`,
+      data_format='channels_last'.
+  Raise:
+    ValueError: if lower bound is not between [0, 1], or upper bound is
+      negative.
+  """
+
+  def __init__(self,
+               height_factor,
+               width_factor,
+               fill_mode='nearest',
+               fill_value=0.,
+               seed=None,
+               **kwargs):
+    self.height_factor = height_factor
+    if isinstance(height_factor, (tuple, list)):
+      self.height_lower = abs(height_factor[0])
+      self.height_upper = height_factor[1]
+    else:
+      self.height_lower = self.height_upper = height_factor
+    if self.height_upper < 0.:
+      raise ValueError('`height_factor` cannot have negative values as upper '
+                       'bound, got {}'.format(height_factor))
+    if abs(self.height_lower) > 1. or abs(self.height_upper) > 1.:
+      raise ValueError('`height_factor` must have values between [-1, 1], '
+                       'got {}'.format(height_factor))
+
+    self.width_factor = width_factor
+    if isinstance(width_factor, (tuple, list)):
+      self.width_lower = abs(width_factor[0])
+      self.width_upper = width_factor[1]
+    else:
+      self.width_lower = self.width_upper = width_factor
+    if self.width_upper < 0.:
+      raise ValueError('`width_factor` cannot have negative values as upper '
+                       'bound, got {}'.format(width_factor))
+    if abs(self.width_lower) > 1. or abs(self.width_upper) > 1.:
+      raise ValueError('`width_factor` must have values between [-1, 1], '
+                       'got {}'.format(width_factor))
+
+    if fill_mode not in {'nearest', 'bilinear'}:
+      raise NotImplementedError(
+          '`fill_mode` {} is not supported yet.'.format(fill_mode))
+    self.fill_mode = fill_mode
+    self.fill_value = fill_value
+    self.seed = seed
+    self._rng = make_generator(self.seed)
+    self.input_spec = InputSpec(ndim=4)
+    super(RandomTranslation, self).__init__(**kwargs)
+
+  def call(self, inputs, training=None):
+    if training is None:
+      training = K.learning_phase()
+
+    def random_translated_inputs():
+      """Translated inputs with random ops."""
+      inputs_shape = array_ops.shape(inputs)
+      batch_size = inputs_shape[0]
+      h_axis, w_axis = 1, 2
+      img_hd = math_ops.cast(inputs_shape[h_axis], dtypes.float32)
+      img_wd = math_ops.cast(inputs_shape[w_axis], dtypes.float32)
+      height_translate = self._rng.uniform(
+          shape=[batch_size, 1],
+          minval=-self.height_lower,
+          maxval=self.height_upper)
+      height_translate = height_translate * img_hd
+      width_translate = self._rng.uniform(
+          shape=[batch_size, 1],
+          minval=-self.width_lower,
+          maxval=self.width_upper)
+      width_translate = width_translate * img_wd
+      translations = math_ops.cast(
+          array_ops.concat([height_translate, width_translate], axis=1),
+          dtype=inputs.dtype)
+      return transform(
+          inputs,
+          get_translation_matrix(translations),
+          interpolation=self.fill_mode)
+
+    output = tf_utils.smart_cond(training, random_translated_inputs,
+                                 lambda: inputs)
+    output.set_shape(inputs.shape)
+    return output
+
+  def compute_output_shape(self, input_shape):
+    return input_shape
+
+  def get_config(self):
+    config = {
+        'height_factor': self.height_factor,
+        'width_factor': self.width_factor,
+        'fill_mode': self.fill_mode,
+        'fill_value': self.fill_value,
+        'seed': self.seed,
+    }
+    base_config = super(RandomTranslation, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+
+def get_translation_matrix(translations, name=None):
+  """Returns projective transform(s) for the given translation(s).
+
+  Args:
+    translations: A matrix of 2-element lists representing [dx, dy] to translate
+      for each image (for a batch of images).
+    name: The name of the op.
+
+  Returns:
+    A tensor of shape (num_images, 8) projective transforms which can be given
+      to `transform`.
+  """
+  with ops.name_scope(name, 'translation_matrix'):
+    num_translations = array_ops.shape(translations)[0]
+    # The translation matrix looks like:
+    #     [[1 0 -dx]
+    #      [0 1 -dy]
+    #      [0 0 1]]
+    # where the last entry is implicit.
+    # Translation matrices are always float32.
+    return array_ops.concat(
+        values=[
+            array_ops.ones((num_translations, 1), dtypes.float32),
+            array_ops.zeros((num_translations, 1), dtypes.float32),
+            -translations[:, 0, None],
+            array_ops.zeros((num_translations, 1), dtypes.float32),
+            array_ops.ones((num_translations, 1), dtypes.float32),
+            -translations[:, 1, None],
+            array_ops.zeros((num_translations, 2), dtypes.float32),
+        ],
+        axis=1)
+
+
+def transform(images,
+              transforms,
+              interpolation='nearest',
+              output_shape=None,
+              name=None):
+  """Applies the given transform(s) to the image(s).
+
+  Args:
+    images: A tensor of shape (num_images, num_rows, num_columns, num_channels)
+      (NHWC), (num_rows, num_columns, num_channels) (HWC), or (num_rows,
+      num_columns) (HW). The rank must be statically known (the shape is not
+      `TensorShape(None)`.
+    transforms: Projective transform matrix/matrices. A vector of length 8 or
+      tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0, b1, b2,
+      c0, c1], then it maps the *output* point `(x, y)` to a transformed *input*
+      point `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where
+      `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to the
+      transform mapping input points to output points. Note that gradients are
+      not backpropagated into transformation parameters.
+    interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
+    output_shape: Output dimesion after the transform, [height, width]. If None,
+      output is the same size as input image.
+    name: The name of the op.
+
+  Returns:
+    Image(s) with the same type and shape as `images`, with the given
+    transform(s) applied. Transformed coordinates outside of the input image
+    will be filled with zeros.
+
+  Raises:
+    TypeError: If `image` is an invalid type.
+    ValueError: If output shape is not 1-D int32 Tensor.
+  """
+  with ops.name_scope(name, 'transform'):
+    if output_shape is None:
+      output_shape = array_ops.shape(images)[1:3]
+      if not context.executing_eagerly():
+        output_shape_value = tensor_util.constant_value(output_shape)
+        if output_shape_value is not None:
+          output_shape = output_shape_value
+
+    output_shape = ops.convert_to_tensor(
+        output_shape, dtypes.int32, name='output_shape')
+
+    if not output_shape.get_shape().is_compatible_with([2]):
+      raise ValueError('output_shape must be a 1-D Tensor of 2 elements: '
+                       'new_height, new_width, instead got '
+                       '{}'.format(output_shape))
+
+    return image_ops.image_projective_transform_v2(
+        images,
+        output_shape=output_shape,
+        transforms=transforms,
+        interpolation=interpolation.upper())
+
+
 class RandomContrast(Layer):
   """Adjust the contrast of an image or images by a random factor.
 
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
index 3249136753e..6710cdaf78f 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.layers.preprocessing import image_preprocessing
 from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.ops import gen_stateful_random_ops
 from tensorflow.python.ops import image_ops_impl as image_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import stateless_random_ops
@@ -459,5 +460,58 @@ class RandomContrastTest(keras_parameterized.TestCase):
     self.assertEqual(layer_1.name, layer.name)
 
 
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class RandomTranslationTest(keras_parameterized.TestCase):
+
+  def _run_test(self, height_factor, width_factor):
+    np.random.seed(1337)
+    num_samples = 2
+    orig_height = 5
+    orig_width = 8
+    channels = 3
+    kwargs = {'height_factor': height_factor, 'width_factor': width_factor}
+    with tf_test_util.use_gpu():
+      testing_utils.layer_test(
+          image_preprocessing.RandomTranslation,
+          kwargs=kwargs,
+          input_shape=(num_samples, orig_height, orig_width, channels),
+          expected_output_shape=(None, orig_height, orig_width, channels))
+
+  @parameterized.named_parameters(
+      ('random_translate_4_by_6', .4, .6), ('random_translate_3_by_2', .3, .2),
+      ('random_translate_tuple_factor', (.5, .4), (.2, .3)))
+  def test_random_translation(self, height_factor, width_factor):
+    self._run_test(height_factor, width_factor)
+
+  def test_random_translation_negative_lower(self):
+    mock_offset = np.random.random((12, 1))
+    with test.mock.patch.object(
+        gen_stateful_random_ops, 'stateful_uniform', return_value=mock_offset):
+      with self.cached_session(use_gpu=True):
+        layer = image_preprocessing.RandomTranslation((-0.2, .3), .4)
+        layer_2 = image_preprocessing.RandomTranslation((0.2, .3), .4)
+        inp = np.random.random((12, 5, 8, 3)).astype(np.float32)
+        actual_output = layer(inp, training=1)
+        actual_output_2 = layer_2(inp, training=1)
+        self.assertAllClose(actual_output, actual_output_2)
+
+  def test_random_translation_inference(self):
+    with CustomObjectScope(
+        {'RandomTranslation': image_preprocessing.RandomTranslation}):
+      input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
+      expected_output = input_images
+      with tf_test_util.use_gpu():
+        layer = image_preprocessing.RandomTranslation(.5, .5)
+        actual_output = layer(input_images, training=0)
+        self.assertAllClose(expected_output, actual_output)
+
+  @tf_test_util.run_v2_only
+  def test_config_with_custom_name(self):
+    layer = image_preprocessing.RandomTranslation(.5, .6, name='image_preproc')
+    config = layer.get_config()
+    layer_1 = image_preprocessing.RandomTranslation.from_config(config)
+    self.assertEqual(layer_1.name, layer.name)
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 3de46e7cf3f..8f5a38fb9ab 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -22,7 +22,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_image_ops
+from tensorflow.python.ops import linalg_ops
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_image_ops import *
@@ -34,3 +39,105 @@ from tensorflow.python.ops.image_ops_impl import *
 from tensorflow.python.ops.image_ops_impl import _Check3DImage
 from tensorflow.python.ops.image_ops_impl import _ImageDimensions
 # pylint: enable=unused-import
+
+_IMAGE_DTYPES = frozenset([
+    dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
+    dtypes.float64
+])
+
+
+def flat_transforms_to_matrices(transforms):
+  """Converts `tf.contrib.image` projective transforms to affine matrices.
+
+  Note that the output matrices map output coordinates to input coordinates. For
+  the forward transformation matrix, call `tf.linalg.inv` on the result.
+
+  Args:
+    transforms: Vector of length 8, or batches of transforms with shape `(N,
+      8)`.
+
+  Returns:
+    3D tensor of matrices with shape `(N, 3, 3)`. The output matrices map the
+      *output coordinates* (in homogeneous coordinates) of each transform to the
+      corresponding *input coordinates*.
+
+  Raises:
+    ValueError: If `transforms` have an invalid shape.
+  """
+  with ops.name_scope("flat_transforms_to_matrices"):
+    transforms = ops.convert_to_tensor(transforms, name="transforms")
+    if transforms.shape.ndims not in (1, 2):
+      raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms)
+    # Make the transform(s) 2D in case the input is a single transform.
+    transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8]))
+    num_transforms = array_ops.shape(transforms)[0]
+    # Add a column of ones for the implicit last entry in the matrix.
+    return array_ops.reshape(
+        array_ops.concat(
+            [transforms, array_ops.ones([num_transforms, 1])], axis=1),
+        constant_op.constant([-1, 3, 3]))
+
+
+def matrices_to_flat_transforms(transform_matrices):
+  """Converts affine matrices to `tf.contrib.image` projective transforms.
+
+  Note that we expect matrices that map output coordinates to input coordinates.
+  To convert forward transformation matrices, call `tf.linalg.inv` on the
+  matrices and use the result here.
+
+  Args:
+    transform_matrices: One or more affine transformation matrices, for the
+      reverse transformation in homogeneous coordinates. Shape `(3, 3)` or `(N,
+      3, 3)`.
+
+  Returns:
+    2D tensor of flat transforms with shape `(N, 8)`, which may be passed into
+      `tf.contrib.image.transform`.
+
+  Raises:
+    ValueError: If `transform_matrices` have an invalid shape.
+  """
+  with ops.name_scope("matrices_to_flat_transforms"):
+    transform_matrices = ops.convert_to_tensor(
+        transform_matrices, name="transform_matrices")
+    if transform_matrices.shape.ndims not in (2, 3):
+      raise ValueError("Matrices should be 2D or 3D, got: %s" %
+                       transform_matrices)
+    # Flatten each matrix.
+    transforms = array_ops.reshape(transform_matrices,
+                                   constant_op.constant([-1, 9]))
+    # Divide each matrix by the last entry (normally 1).
+    transforms /= transforms[:, 8:9]
+    return transforms[:, :8]
+
+
+@ops.RegisterGradient("ImageProjectiveTransformV2")
+def _image_projective_transform_grad(op, grad):
+  """Computes the gradient for ImageProjectiveTransform."""
+  images = op.inputs[0]
+  transforms = op.inputs[1]
+  interpolation = op.get_attr("interpolation")
+
+  image_or_images = ops.convert_to_tensor(images, name="images")
+  transform_or_transforms = ops.convert_to_tensor(
+      transforms, name="transforms", dtype=dtypes.float32)
+
+  if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
+    raise TypeError("Invalid dtype %s." % image_or_images.dtype)
+  if len(transform_or_transforms.get_shape()) == 1:
+    transforms = transform_or_transforms[None]
+  elif len(transform_or_transforms.get_shape()) == 2:
+    transforms = transform_or_transforms
+  else:
+    raise TypeError("Transforms should have rank 1 or 2.")
+
+  # Invert transformations
+  transforms = flat_transforms_to_matrices(transforms=transforms)
+  inverse = linalg_ops.matrix_inverse(transforms)
+  transforms = matrices_to_flat_transforms(inverse)
+  output = gen_image_ops.image_projective_transform_v2(
+      images=grad,
+      transforms=transforms,
+      output_shape=array_ops.shape(image_or_images)[1:3],
+      interpolation=interpolation)
+  return [output, None, None]
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 5f29cd1cd33..e6b34d45b35 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1764,6 +1764,10 @@ tf_module {
     name: "Imag"
     argspec: "args=[\'input\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
+  member_method {
+    name: "ImageProjectiveTransformV2"
+    argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "ImageSummary"
     argspec: "args=[\'tag\', \'tensor\', \'max_images\', \'bad_color\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'dtype: DT_UINT8\\ntensor_shape {\\n  dim {\\n    size: 4\\n  }\\n}\\nint_val: 255\\nint_val: 0\\nint_val: 0\\nint_val: 255\\n\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 5f29cd1cd33..e6b34d45b35 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1764,6 +1764,10 @@ tf_module {
     name: "Imag"
     argspec: "args=[\'input\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
+  member_method {
+    name: "ImageProjectiveTransformV2"
+    argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "ImageSummary"
     argspec: "args=[\'tag\', \'tensor\', \'max_images\', \'bad_color\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'dtype: DT_UINT8\\ntensor_shape {\\n  dim {\\n    size: 4\\n  }\\n}\\nint_val: 255\\nint_val: 0\\nint_val: 0\\nint_val: 255\\n\', \'None\'], "