diff --git a/tensorflow/core/kernels/attention_ops.cc b/tensorflow/core/kernels/attention_ops.cc
index f555c0fd679..6e5e07a9fb1 100644
--- a/tensorflow/core/kernels/attention_ops.cc
+++ b/tensorflow/core/kernels/attention_ops.cc
@@ -32,6 +32,8 @@ namespace tensorflow {
 class ExtractGlimpseOp : public OpKernel {
  public:
   explicit ExtractGlimpseOp(OpKernelConstruction* context) : OpKernel(context) {
+    const string& op = context->def().op();
+    version_ = (op == "ExtractGlimpse") ? 1 : 2;
     OP_REQUIRES_OK(context, context->GetAttr("normalized", &normalized_));
     OP_REQUIRES_OK(context, context->GetAttr("centered", &centered_));
     bool uniform_noise = false;
@@ -117,21 +119,23 @@ class ExtractGlimpseOp : public OpKernel {
       // calling TensorFlow operates with (y,x) as indices.
       offset_vec.push_back(Eigen::IndexPair<float>(offset_x, offset_y));
     }
-
     output->tensor<float, 4>().swap_layout().device(
         context->eigen_cpu_device()) =
         Eigen::ExtractGlimpses(input.tensor<float, 4>().swap_layout(),
                                output_width, output_height, offset_vec,
-                               normalized_, centered_, noise_);
+                               normalized_, centered_, noise_, version_);
   }
 
  private:
   bool normalized_;
   bool centered_;
   Eigen::ExtractGlimpsesNoiseMode noise_;
+  int32 version_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("ExtractGlimpse").Device(DEVICE_CPU),
                         ExtractGlimpseOp);
+REGISTER_KERNEL_BUILDER(Name("ExtractGlimpseV2").Device(DEVICE_CPU),
+                        ExtractGlimpseOp);
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/eigen_attention.h b/tensorflow/core/kernels/eigen_attention.h
index 7cf5c53dfca..ca61e223c21 100644
--- a/tensorflow/core/kernels/eigen_attention.h
+++ b/tensorflow/core/kernels/eigen_attention.h
@@ -56,13 +56,15 @@ struct GlimpseExtractionOp {
   GlimpseExtractionOp(const Index width, const Index height,
                       const std::vector<IndexPair<float> >& offsets,
                       const bool normalized, const bool centered,
-                      const ExtractGlimpsesNoiseMode noise)
+                      const ExtractGlimpsesNoiseMode noise,
+		      const int version)
       : width_(width),
         height_(height),
         offsets_(offsets),
         normalized_(normalized),
         centered_(centered),
-        noise_(noise) {}
+        noise_(noise),
+        version_(version) {}
 
   template <typename Input>
   DSizes<Index, 4> dimensions(const Input& input) const {
@@ -101,24 +103,42 @@ struct GlimpseExtractionOp {
     for (Index i = 0; i < batch_size; ++i) {
       float x = offsets_[i].first, y = offsets_[i].second;
 
-      if (normalized_) {
+      if (version_ == 1) {
         // Un-normalize coordinates back to pixel space if normalized.
-        x *= input_width;
-        y *= input_height;
+        if (normalized_) {
+          x *= input_width;
+          y *= input_height;
+        }
+        // Un-center if coordinates are centered on the image center.
         if (centered_) {
-          // Un-center if coordinates are centered on the image center.
           x /= 2.0f;
           y /= 2.0f;
           x += input_width / 2.0f;
           y += input_height / 2.0f;
-          // Remove half of the glimpse window.
-          x -= width_ / 2.0f;
-          y -= height_ / 2.0f;
         }
+        // Remove half of the glimpse window.
+        x -= width_ / 2.0f;
+        y -= height_ / 2.0f;
       } else {
-        if (centered_) {
-          x += input_width / 2.0f;
-          y += input_height / 2.0f;
+        if (normalized_) {
+          // Un-normalize coordinates back to pixel space if normalized.
+          x *= input_width;
+          y *= input_height;
+          if (centered_) {
+            // Un-center if coordinates are centered on the image center.
+            x /= 2.0f;
+            y /= 2.0f;
+            x += input_width / 2.0f;
+            y += input_height / 2.0f;
+            // Remove half of the glimpse window.
+            x -= width_ / 2.0f;
+            y -= height_ / 2.0f;
+          }
+        } else {
+          if (centered_) {
+            x += input_width / 2.0f;
+            y += input_height / 2.0f;
+          }
         }
       }
 
@@ -248,6 +268,7 @@ struct GlimpseExtractionOp {
   const bool normalized_;
   const bool centered_;
   const ExtractGlimpsesNoiseMode noise_;
+  const int version_;
 };
 }  // namespace
 
@@ -260,7 +281,8 @@ ExtractGlimpses(
     const typename internal::traits<Input>::Index height,
     const std::vector<IndexPair<float> >& offsets, const bool normalized = true,
     const bool centered = true,
-    const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM) {
+    const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM,
+    const int version = 2) {
   EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor,
                       YOU_MADE_A_PROGRAMMING_MISTAKE);
   EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4,
@@ -268,7 +290,7 @@ ExtractGlimpses(
 
   typedef typename internal::traits<Input>::Index Index;
   const GlimpseExtractionOp<Index> op(width, height, offsets, normalized,
-                                      centered, noise);
+                                      centered, noise, version);
   return input.customOp(op);
 }
 
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 418f1e20e37..e11f14b8538 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -756,6 +756,41 @@ REGISTER_OP("ExtractGlimpse")
                                    c->Dim(input, 3));
     });
 
+REGISTER_OP("ExtractGlimpseV2")
+    .Input("input: float")
+    .Input("size: int32")
+    .Input("offsets: float")
+    .Output("glimpse: float")
+    .Attr("centered: bool = true")
+    .Attr("normalized: bool = true")
+    .Attr("uniform_noise: bool = true")
+    .Attr("noise: string = 'uniform'")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+      ShapeHandle offsets;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
+
+      DimensionHandle batch_dim;
+      TF_RETURN_IF_ERROR(
+          c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
+      DimensionHandle unused;
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
+
+      bool uniform_noise = false;
+      TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise));
+      string noise;
+      TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise));
+      if (uniform_noise && (!noise.empty() && noise != "uniform")) {
+        return errors::InvalidArgument(
+            "The uniform_noise and noise should not be specified at the same "
+            "time");
+      }
+
+      return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
+                                   c->Dim(input, 3));
+    });
+
 // --------------------------------------------------------------------------
 
 REGISTER_OP("CropAndResize")
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index bd0722f32f9..49f44872ebf 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -4063,10 +4063,10 @@ def extract_glimpse(
   >>> tf.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]],
   ...                         centered=False, normalized=False)
   <tf.Tensor: shape=(1, 2, 2, 1), dtype=float32, numpy=
-  array([[[[4.],
-           [5.]],
-          [[7.],
-           [8.]]]], dtype=float32)>
+  array([[[[0.],
+           [1.]],
+          [[3.],
+           [4.]]]], dtype=float32)>
 
   Args:
     input: A `Tensor` of type `float32`. A 4-D float tensor of shape
@@ -4176,7 +4176,7 @@ def extract_glimpse_v2(
   Returns:
     A `Tensor` of type `float32`.
   """
-  return gen_image_ops.extract_glimpse(
+  return gen_image_ops.extract_glimpse_v2(
       input=input,
       size=size,
       offsets=offsets,