diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f66102ab3ac..8a241933867 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3166,6 +3166,7 @@ tf_cc_tests(
         "adjust_contrast_op_test.cc",
         "colorspace_op_test.cc",
         "crop_and_resize_op_test.cc",
+        "mirror_pad_op_test.cc",
         "non_max_suppression_op_test.cc",
         "resize_area_op_test.cc",
         "resize_bicubic_op_test.cc",
@@ -3178,6 +3179,7 @@ tf_cc_tests(
     }),
     deps = [
         ":image",
+        ":mirror_pad_op",
         ":ops_testutil",
         ":ops_util",
         ":sampling_kernels",
@@ -3244,6 +3246,22 @@ tf_cuda_cc_test(
     ],
 )
 
+tf_cuda_cc_test(
+    name = "mirror_pad_op_benchmark_test",
+    srcs = ["mirror_pad_op_benchmark_test.cc"],
+    deps = [
+        ":mirror_pad_op",
+        ":ops_testutil",
+        ":ops_util",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_cuda_cc_test(
     name = "non_max_suppression_op_gpu_test",
     srcs = ["non_max_suppression_op_gpu_test.cc"],
diff --git a/tensorflow/core/kernels/mirror_pad_op.h b/tensorflow/core/kernels/mirror_pad_op.h
index eda3b2b9e2a..b94aec9a68b 100644
--- a/tensorflow/core/kernels/mirror_pad_op.h
+++ b/tensorflow/core/kernels/mirror_pad_op.h
@@ -16,9 +16,9 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
 #define TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/platform/types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace Eigen {
 template <typename PaddingDimensions, typename XprType>
@@ -223,7 +223,8 @@ struct TensorEvaluator<const TensorMirrorPadOp<PaddingDimensions, ArgType>,
     const Index right =
         (dimensions_[dim] - padding_[dim].second) * output_strides_[dim];
 
-    if (left <= index && (index + kPacketSize - 1) < right) {
+    const Index index_mod = index % (dimensions_[dim] * output_strides_[dim]);
+    if (left <= index_mod && (index_mod + kPacketSize - 1) < right) {
       return impl_.template packet<Unaligned>(input_index);
     }
 
diff --git a/tensorflow/core/kernels/mirror_pad_op_benchmark_test.cc b/tensorflow/core/kernels/mirror_pad_op_benchmark_test.cc
new file mode 100644
index 00000000000..733d2350fdd
--- /dev/null
+++ b/tensorflow/core/kernels/mirror_pad_op_benchmark_test.cc
@@ -0,0 +1,59 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+static Graph* BM_MirrorPad(int batches, int height, int width, int depth,
+                           int pad, const char* mode) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor in(DT_FLOAT, TensorShape({batches, height, width, depth}));
+  in.flat<float>().setRandom();
+  Tensor padding(DT_INT32, TensorShape({4, 2}));
+  auto boxes_tensor = padding.flat<int>().setZero();
+  for (int i = 2; i < 6; i++) boxes_tensor(i) = pad;
+
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MirrorPad")
+                  .Input(test::graph::Constant(g, in))
+                  .Input(test::graph::Constant(g, padding))
+                  .Attr("mode", mode)
+                  .Finalize(g, &ret));
+  return g;
+}
+
+#define BM_MirrorPadDev(DEVICE, B, W, H, D, P, MODE)                         \
+  static void BM_MirrorPad_##DEVICE##_##B##_##W##_##H##_##D##_##P##_##MODE(  \
+      int iters) {                                                           \
+    testing::ItemsProcessed(iters* B*(W + 2 * P) * (H + 2 * P) * D / 32);    \
+    test::Benchmark(#DEVICE, BM_MirrorPad(B, W, H, D, P, #MODE)).Run(iters); \
+  }                                                                          \
+  BENCHMARK(BM_MirrorPad_##DEVICE##_##B##_##W##_##H##_##D##_##P##_##MODE);
+
+BM_MirrorPadDev(cpu, 1, 16, 16, 32, 1, REFLECT);
+BM_MirrorPadDev(cpu, 1, 16, 16, 32, 8, REFLECT);
+BM_MirrorPadDev(cpu, 1, 512, 512, 16, 1, REFLECT);
+BM_MirrorPadDev(cpu, 1, 512, 512, 16, 256, REFLECT);
+BM_MirrorPadDev(cpu, 1, 16, 16, 32, 1, SYMMETRIC);
+BM_MirrorPadDev(cpu, 1, 16, 16, 32, 8, SYMMETRIC);
+BM_MirrorPadDev(cpu, 1, 512, 512, 16, 1, SYMMETRIC);
+BM_MirrorPadDev(cpu, 1, 512, 512, 16, 256, SYMMETRIC);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mirror_pad_op_test.cc b/tensorflow/core/kernels/mirror_pad_op_test.cc
new file mode 100644
index 00000000000..0afae5dd69b
--- /dev/null
+++ b/tensorflow/core/kernels/mirror_pad_op_test.cc
@@ -0,0 +1,201 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+class MirrorPadOpTest : public OpsTestBase {
+ protected:
+  template <typename T>
+  void MakeOp(const string& mode) {
+    TF_EXPECT_OK(NodeDefBuilder("mirror_pad_op", "MirrorPad")
+                     .Input(FakeInput(DataTypeToEnum<T>::value))
+                     .Input(FakeInput(DT_INT32))
+                     .Attr("mode", mode)
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+  }
+};
+
+#define REGISTER_TEST(T)                                                     \
+  TEST_F(MirrorPadOpTest, TestMirrorPadReflect##T) {                         \
+    MakeOp<T>("REFLECT");                                                    \
+    AddInputFromArray<T>(TensorShape({1, 2, 3, 1}), {1, 2, 3, 4, 5, 6});     \
+    AddInputFromArray<int32>(TensorShape({4, 2}), {0, 0, 1, 1, 2, 2, 0, 0}); \
+    TF_ASSERT_OK(RunOpKernel());                                             \
+                                                                             \
+    Tensor expected(allocator(), DataTypeToEnum<T>::value,                   \
+                    TensorShape({1, 4, 7, 1}));                              \
+    test::FillValues<T>(&expected,                                           \
+                        {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1,           \
+                         6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1});         \
+    test::ExpectTensorEqual<T>(expected, *GetOutput(0));                     \
+  }                                                                          \
+                                                                             \
+  TEST_F(MirrorPadOpTest, TestMirrorPadSymmetric##T) {                       \
+    MakeOp<T>("SYMMETRIC");                                                  \
+    AddInputFromArray<T>(TensorShape({1, 2, 1, 3}), {1, 2, 3, 4, 5, 6});     \
+    AddInputFromArray<int32>(TensorShape({4, 2}), {1, 1, 0, 0, 0, 0, 2, 2}); \
+    TF_ASSERT_OK(RunOpKernel());                                             \
+                                                                             \
+    Tensor expected(allocator(), DataTypeToEnum<T>::value,                   \
+                    TensorShape({3, 2, 1, 7}));                              \
+    test::FillValues<T>(                                                     \
+        &expected,                                                           \
+        {2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2,      \
+         5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5});    \
+    test::ExpectTensorEqual<T>(expected, *GetOutput(0));                     \
+  }
+
+REGISTER_TEST(float)
+REGISTER_TEST(double)
+REGISTER_TEST(uint8)
+REGISTER_TEST(uint16)
+REGISTER_TEST(int8)
+REGISTER_TEST(int16)
+REGISTER_TEST(int32)
+REGISTER_TEST(int64)
+
+#undef REGISTER_TEST
+
+TEST_F(MirrorPadOpTest, TestMirrorPadReflectLargeInput) {
+  MakeOp<float>("REFLECT");
+  // Generate a relatively large input
+  const int kInput = 1000;
+  const int kPad = 10;
+  const int kOutput = kInput + 2 * kPad;
+
+  // Input:
+  //  0, 1, 2, ..., 999
+  //  0, 1, 2, ..., 999
+  //  ... (altogether 1000 lines)
+  //  0, 1, 2, ..., 999
+  AddInput<float>(TensorShape({1, kInput, kInput, 1}),
+                  [](int i) -> float { return i % kInput; });
+  AddInputFromArray<int32>(TensorShape({4, 2}),
+                           {0, 0, kPad, kPad, kPad, kPad, 0, 0});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({1, kOutput, kOutput, 1}));
+  test::FillFn<float>(&expected, [](int i) -> float {
+    i = i % kOutput;
+    if (0 <= i && i < kPad)
+      return kPad - i;
+    else if (kPad <= i && i < kInput + kPad)
+      return i - kPad;
+    else if (kInput + kPad <= i && i < kOutput)
+      return 2 * kInput + kPad - 2 - i;
+  });
+
+  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(MirrorPadOpTest, TestMirrorPadSymmetricLargeInput) {
+  MakeOp<float>("SYMMETRIC");
+  // Generate a relatively large input
+  const int kInput = 1000;
+  const int kPad = 10;
+  const int kOutput = kInput + 2 * kPad;
+
+  // Input:
+  //  0, 1, 2, ..., 999
+  //  0, 1, 2, ..., 999
+  //  ... (altogether 1000 lines)
+  //  0, 1, 2, ..., 999
+  AddInput<float>(TensorShape({1, kInput, kInput, 1}),
+                  [](int i) -> float { return i % kInput; });
+  AddInputFromArray<int32>(TensorShape({4, 2}),
+                           {0, 0, kPad, kPad, kPad, kPad, 0, 0});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({1, kOutput, kOutput, 1}));
+  test::FillFn<float>(&expected, [](int i) -> float {
+    i = i % kOutput;
+    if (0 <= i && i < kPad)
+      return kPad - i - 1;
+    else if (kPad <= i && i < kInput + kPad)
+      return i - kPad;
+    else if (kInput + kPad <= i && i < kOutput)
+      return 2 * kInput + kPad - 1 - i;
+  });
+
+  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+class MirrorPadGradOpTest : public OpsTestBase {
+ protected:
+  template <typename T>
+  void MakeOp(const string& mode) {
+    TF_EXPECT_OK(NodeDefBuilder("mirror_pad_grad_op", "MirrorPadGrad")
+                     .Input(FakeInput(DataTypeToEnum<T>::value))
+                     .Input(FakeInput(DT_INT32))
+                     .Attr("mode", mode)
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+  }
+};
+
+#define REGISTER_TEST(T)                                                      \
+  TEST_F(MirrorPadGradOpTest, TestMirrorPadGradReflect##T) {                  \
+    MakeOp<T>("REFLECT");                                                     \
+    AddInput<T>(TensorShape({1, 4, 7, 1}), [](int i) -> T { return i % 7; }); \
+    AddInputFromArray<int32>(TensorShape({4, 2}), {0, 0, 1, 1, 2, 2, 0, 0});  \
+    TF_ASSERT_OK(RunOpKernel());                                              \
+                                                                              \
+    Tensor expected(allocator(), DataTypeToEnum<T>::value,                    \
+                    TensorShape({1, 2, 3, 1}));                               \
+    test::FillValues<T>(&expected, {16, 18, 8, 16, 18, 8});                   \
+    test::ExpectTensorEqual<T>(expected, *GetOutput(0));                      \
+  }                                                                           \
+                                                                              \
+  TEST_F(MirrorPadGradOpTest, TestMirrorPadGradSymmetric##T) {                \
+    MakeOp<T>("SYMMETRIC");                                                   \
+    AddInput<T>(TensorShape({3, 2, 1, 7}), [](int i) -> T { return i % 7; }); \
+    AddInputFromArray<int32>(TensorShape({4, 2}), {1, 1, 0, 0, 0, 0, 2, 2});  \
+    TF_ASSERT_OK(RunOpKernel());                                              \
+                                                                              \
+    Tensor expected(allocator(), DataTypeToEnum<T>::value,                    \
+                    TensorShape({1, 2, 1, 3}));                               \
+    test::FillValues<T>(&expected, {9, 27, 27, 9, 27, 27});                   \
+    test::ExpectTensorEqual<T>(expected, *GetOutput(0));                      \
+  }
+
+REGISTER_TEST(float)
+REGISTER_TEST(double)
+REGISTER_TEST(uint8)
+REGISTER_TEST(uint16)
+REGISTER_TEST(int8)
+REGISTER_TEST(int16)
+REGISTER_TEST(int32)
+REGISTER_TEST(int64)
+
+#undef REGISTER_TEST
+
+}  // namespace tensorflow