diff --git a/tensorflow/core/kernels/reduction_ops_all.cc b/tensorflow/core/kernels/reduction_ops_all.cc
index 41abc2b9574..4a34c4ef513 100644
--- a/tensorflow/core/kernels/reduction_ops_all.cc
+++ b/tensorflow/core/kernels/reduction_ops_all.cc
@@ -22,7 +22,13 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<int32>("Tidx")
         .Device(DEVICE_CPU)
         .HostMemory("reduction_indices"),
-    ReductionOp<CPUDevice, bool, Eigen::internal::AndReducer>);
+    ReductionOp<CPUDevice, bool, int32, Eigen::internal::AndReducer>);
+REGISTER_KERNEL_BUILDER(  // "All" on DEVICE_CPU when "Tidx" is int64.
+    Name("All")
+        .TypeConstraint<int64>("Tidx")
+        .Device(DEVICE_CPU)
+        .HostMemory("reduction_indices"),
+    ReductionOp<CPUDevice, bool, int64, Eigen::internal::AndReducer>);
 
 #if GOOGLE_CUDA
 REGISTER_KERNEL_BUILDER(
@@ -30,7 +36,13 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<int32>("Tidx")
         .Device(DEVICE_GPU)
         .HostMemory("reduction_indices"),
-    ReductionOp<GPUDevice, bool, Eigen::internal::AndReducer>);
+    ReductionOp<GPUDevice, bool, int32, Eigen::internal::AndReducer>);
+REGISTER_KERNEL_BUILDER(  // "All" on DEVICE_GPU when "Tidx" is int64.
+    Name("All")
+        .TypeConstraint<int64>("Tidx")
+        .Device(DEVICE_GPU)
+        .HostMemory("reduction_indices"),
+    ReductionOp<GPUDevice, bool, int64, Eigen::internal::AndReducer>);
 #endif
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_any.cc b/tensorflow/core/kernels/reduction_ops_any.cc
index a2087cc3b7b..6c0519de95e 100644
--- a/tensorflow/core/kernels/reduction_ops_any.cc
+++ b/tensorflow/core/kernels/reduction_ops_any.cc
@@ -22,7 +22,13 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<int32>("Tidx")
         .Device(DEVICE_CPU)
         .HostMemory("reduction_indices"),
-    ReductionOp<CPUDevice, bool, Eigen::internal::OrReducer>);
+    ReductionOp<CPUDevice, bool, int32, Eigen::internal::OrReducer>);
+REGISTER_KERNEL_BUILDER(  // "Any" on DEVICE_CPU when "Tidx" is int64.
+    Name("Any")
+        .TypeConstraint<int64>("Tidx")
+        .Device(DEVICE_CPU)
+        .HostMemory("reduction_indices"),
+    ReductionOp<CPUDevice, bool, int64, Eigen::internal::OrReducer>);
 
 #if GOOGLE_CUDA
 REGISTER_KERNEL_BUILDER(
@@ -30,7 +36,13 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<int32>("Tidx")
         .Device(DEVICE_GPU)
         .HostMemory("reduction_indices"),
-    ReductionOp<GPUDevice, bool, Eigen::internal::OrReducer>);
+    ReductionOp<GPUDevice, bool, int32, Eigen::internal::OrReducer>);
+REGISTER_KERNEL_BUILDER(  // "Any" on DEVICE_GPU when "Tidx" is int64.
+    Name("Any")
+        .TypeConstraint<int64>("Tidx")
+        .Device(DEVICE_GPU)
+        .HostMemory("reduction_indices"),
+    ReductionOp<GPUDevice, bool, int64, Eigen::internal::OrReducer>);
 #endif
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc
index 5eba4288acc..8daab0d6be4 100644
--- a/tensorflow/core/kernels/reduction_ops_common.cc
+++ b/tensorflow/core/kernels/reduction_ops_common.cc
@@ -57,13 +57,12 @@ gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
   return perm;
 }
 
-Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
-                                 const bool keep_dims) {
-  // bitmap[i] indicates whether to reduce data along i-th axis.
-  gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
-  auto axis_vec = axis.flat<int32>();
+template <typename Tperm>  // Sets bitmap[d] for each dim d listed in `axis`.
+Status SimplifyHelper(const Tensor& data, const Tensor& axis,
+                      gtl::InlinedVector<bool, 4>& bitmap) {  // out: axis mask
+  auto axis_vec = axis.flat<Tperm>();  // Axis entries, read as Tperm.
   for (int64 i = 0; i < axis.NumElements(); ++i) {
-    int32 index = axis_vec(i);
+    Tperm index = axis_vec(i);  // Negative values are wrapped below.
     if (index < -data.dims() || index >= data.dims()) {
       return errors::InvalidArgument("Invalid reduction dimension (", index,
                                      " for input with ", data.dims(),
@@ -72,7 +71,18 @@ Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
     index = (index + data.dims()) % data.dims();
     bitmap[index] = true;
   }
+  return Status::OK();  // No entry failed the range check.
+}
 
+Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
+                                 const bool keep_dims) {
+  // bitmap[i] indicates whether to reduce data along i-th axis.
+  gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
+  if (axis.dtype() == DT_INT32) {
+    TF_RETURN_IF_ERROR(SimplifyHelper<int32>(data, axis, bitmap));
+  } else {
+    TF_RETURN_IF_ERROR(SimplifyHelper<int64>(data, axis, bitmap));
+  }
   // Output tensor's dim sizes.
   out_shape_.clear();
   for (int i = 0; i < data.dims(); ++i) {
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 71af9d88dc1..9da992ccd18 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -25,6 +25,7 @@ limitations under the License.
 
 #include "third_party/eigen3/Eigen/Core"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -42,7 +43,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 #ifdef TENSORFLOW_USE_SYCL
 typedef Eigen::SyclDevice SYCLDevice;
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 template <typename Device>
 struct Constants {
@@ -68,11 +69,13 @@ struct ConstantsBase {
   const Eigen::IndexList<Eigen::type2index<1>> kOne;
   const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo;
 };
-template<> struct Constants<CPUDevice> : ConstantsBase{};
+template <>
+struct Constants<CPUDevice> : ConstantsBase {};
 #ifdef TENSORFLOW_USE_SYCL
-template<> struct Constants<SYCLDevice> : ConstantsBase{};
-#endif // TENSORFLOW_USE_SYCL
-#endif // EIGEN_HAS_INDEX_LIST
+template <>
+struct Constants<SYCLDevice> : ConstantsBase {};
+#endif  // TENSORFLOW_USE_SYCL
+#endif  // EIGEN_HAS_INDEX_LIST
 
 class ReductionHelper {
  public:
@@ -131,12 +134,13 @@ class ReductionHelper {
 
 // For operations where the output is a reduction function along some
 // dimensions of the input.
-template <typename Device, class T, typename Reducer>
+template <typename Device, class T, typename Tperm, typename Reducer>
 class ReductionOp : public OpKernel {
  public:
   explicit ReductionOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     const DataType dt = DataTypeToEnum<T>::v();
-    OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
+    const DataType pt = DataTypeToEnum<Tperm>::v();  // Second input's dtype.
+    OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, pt}, {dt}));  // T, Tperm -> T
 
     OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
   }
@@ -266,20 +270,19 @@ struct ReduceFunctorBase {
   }
 
   template <typename OUT_T>
-  static void FillIdentity(const Device& d, OUT_T out,
-                           const Reducer& reducer) {
+  static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer) {
     FillIdentityEigenImpl(d, out, reducer);
   }
 };
 
 template <typename Reducer>
 struct ReduceFunctor<CPUDevice, Reducer>
-        : ReduceFunctorBase<CPUDevice, Reducer>{};
+    : ReduceFunctorBase<CPUDevice, Reducer> {};
 #if TENSORFLOW_USE_SYCL
 template <typename Reducer>
 struct ReduceFunctor<SYCLDevice, Reducer>
-        : ReduceFunctorBase<SYCLDevice, Reducer>{};
-#endif // TENSORFLOW_USE_SYCL
+    : ReduceFunctorBase<SYCLDevice, Reducer> {};
+#endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace functor
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index 4ca5c11a485..9cf953f4bfe 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -17,26 +17,39 @@ limitations under the License.
 
 namespace tensorflow {
 
-#define REGISTER_CPU_KERNELS(type)        \
-  REGISTER_KERNEL_BUILDER(                \
-      Name("Max")                         \
-          .Device(DEVICE_CPU)             \
-          .TypeConstraint<type>("T")      \
-          .TypeConstraint<int32>("Tidx"), \
-      ReductionOp<CPUDevice, type, Eigen::internal::MaxReducer<type>>);
+#define REGISTER_CPU_KERNELS(type)                                             \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Max")                                                              \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int32>("Tidx"),                                      \
+      ReductionOp<CPUDevice, type, int32, Eigen::internal::MaxReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Max")                                                              \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int64>("Tidx"),                                      \
+      ReductionOp<CPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
 #if GOOGLE_CUDA
 
-#define REGISTER_GPU_KERNELS(type)          \
-  REGISTER_KERNEL_BUILDER(                  \
-      Name("Max")                           \
-          .Device(DEVICE_GPU)               \
-          .TypeConstraint<type>("T")        \
-          .TypeConstraint<int32>("Tidx")    \
-          .HostMemory("reduction_indices"), \
-      ReductionOp<GPUDevice, type, Eigen::internal::MaxReducer<type>>);
+#define REGISTER_GPU_KERNELS(type)                                             \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Max")                                                              \
+          .Device(DEVICE_GPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int32>("Tidx")                                       \
+          .HostMemory("reduction_indices"),                                    \
+      ReductionOp<GPUDevice, type, int32, Eigen::internal::MaxReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Max")                                                              \
+          .Device(DEVICE_GPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int64>("Tidx")                                       \
+          .HostMemory("reduction_indices"),                                    \
+      ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
 REGISTER_GPU_KERNELS(float);
 REGISTER_GPU_KERNELS(double);
 REGISTER_GPU_KERNELS(int64);
@@ -52,21 +65,37 @@ REGISTER_KERNEL_BUILDER(
         .HostMemory("output")
         .TypeConstraint<int32>("T")
         .TypeConstraint<int32>("Tidx"),
-    ReductionOp<CPUDevice, int32, Eigen::internal::MaxReducer<int32>>);
+    ReductionOp<CPUDevice, int32, int32, Eigen::internal::MaxReducer<int32>>);
+REGISTER_KERNEL_BUILDER(  // int32 "Max" with int64 "Tidx" on DEVICE_GPU.
+    Name("Max")
+        .Device(DEVICE_GPU)
+        .HostMemory("reduction_indices")  // These tensors live in host memory;
+        .HostMemory("input")              // the kernel registered below is the
+        .HostMemory("output")             // CPUDevice instantiation.
+        .TypeConstraint<int32>("T")
+        .TypeConstraint<int64>("Tidx"),
+    ReductionOp<CPUDevice, int32, int64, Eigen::internal::MaxReducer<int32>>);
 
 #undef REGISTER_GPU_KERNELS
 
 #endif
 
 #ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNELS(type)         \
-  REGISTER_KERNEL_BUILDER(                  \
-      Name("Max")                           \
-          .Device(DEVICE_SYCL)              \
-          .TypeConstraint<type>("T")        \
-          .TypeConstraint<int32>("Tidx")    \
-          .HostMemory("reduction_indices"), \
-      ReductionOp<SYCLDevice, type, Eigen::internal::MaxReducer<type>>);
+#define REGISTER_SYCL_KERNELS(type)                                        \
+  REGISTER_KERNEL_BUILDER(Name("Max")                                      \
+                              .Device(DEVICE_SYCL)                         \
+                              .TypeConstraint<type>("T")                   \
+                              .TypeConstraint<int32>("Tidx")               \
+                              .HostMemory("reduction_indices"),            \
+                          ReductionOp<SYCLDevice, type, int32,             \
+                                      Eigen::internal::MaxReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(Name("Max")                                      \
+                              .Device(DEVICE_SYCL)                         \
+                              .TypeConstraint<type>("T")                   \
+                              .TypeConstraint<int64>("Tidx")               \
+                              .HostMemory("reduction_indices"),            \
+                          ReductionOp<SYCLDevice, type, int64,             \
+                                      Eigen::internal::MaxReducer<type>>);
 REGISTER_SYCL_KERNELS(float);
 REGISTER_SYCL_KERNELS(double);
 
@@ -78,8 +107,17 @@ REGISTER_KERNEL_BUILDER(
         .HostMemory("output")
         .TypeConstraint<int32>("T")
         .TypeConstraint<int32>("Tidx"),
-    ReductionOp<CPUDevice, int32, Eigen::internal::MaxReducer<int32>>);
+    ReductionOp<CPUDevice, int32, int32, Eigen::internal::MaxReducer<int32>>);
+REGISTER_KERNEL_BUILDER(  // int32 "Max" with int64 "Tidx" on DEVICE_SYCL.
+    Name("Max")
+        .Device(DEVICE_SYCL)
+        .HostMemory("reduction_indices")  // These tensors live in host memory;
+        .HostMemory("input")              // the kernel registered below is the
+        .HostMemory("output")             // CPUDevice instantiation.
+        .TypeConstraint<int32>("T")
+        .TypeConstraint<int64>("Tidx"),
+    ReductionOp<CPUDevice, int32, int64, Eigen::internal::MaxReducer<int32>>);
 #undef REGISTER_SYCL_KERNELS
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_mean.cc b/tensorflow/core/kernels/reduction_ops_mean.cc
index 5b01de8ddbc..f61589f913b 100644
--- a/tensorflow/core/kernels/reduction_ops_mean.cc
+++ b/tensorflow/core/kernels/reduction_ops_mean.cc
@@ -17,26 +17,39 @@ limitations under the License.
 
 namespace tensorflow {
 
-#define REGISTER_CPU_KERNELS(type)        \
-  REGISTER_KERNEL_BUILDER(                \
-      Name("Mean")                        \
-          .Device(DEVICE_CPU)             \
-          .TypeConstraint<type>("T")      \
-          .TypeConstraint<int32>("Tidx"), \
-      ReductionOp<CPUDevice, type, Eigen::internal::MeanReducer<type>>);
+#define REGISTER_CPU_KERNELS(type)                                          \
+  REGISTER_KERNEL_BUILDER(Name("Mean")                                      \
+                              .Device(DEVICE_CPU)                           \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int32>("Tidx"),               \
+                          ReductionOp<CPUDevice, type, int32,               \
+                                      Eigen::internal::MeanReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(Name("Mean")                                      \
+                              .Device(DEVICE_CPU)                           \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int64>("Tidx"),               \
+                          ReductionOp<CPUDevice, type, int64,               \
+                                      Eigen::internal::MeanReducer<type>>);
 TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
 #if GOOGLE_CUDA
 
-#define REGISTER_GPU_KERNELS(type)          \
-  REGISTER_KERNEL_BUILDER(                  \
-      Name("Mean")                          \
-          .Device(DEVICE_GPU)               \
-          .TypeConstraint<type>("T")        \
-          .TypeConstraint<int32>("Tidx")    \
-          .HostMemory("reduction_indices"), \
-      ReductionOp<GPUDevice, type, Eigen::internal::MeanReducer<type>>);
+#define REGISTER_GPU_KERNELS(type)                                          \
+  REGISTER_KERNEL_BUILDER(Name("Mean")                                      \
+                              .Device(DEVICE_GPU)                           \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int32>("Tidx")                \
+                              .HostMemory("reduction_indices"),             \
+                          ReductionOp<GPUDevice, type, int32,               \
+                                      Eigen::internal::MeanReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(Name("Mean")                                      \
+                              .Device(DEVICE_GPU)                           \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int64>("Tidx")                \
+                              .HostMemory("reduction_indices"),             \
+                          ReductionOp<GPUDevice, type, int64,               \
+                                      Eigen::internal::MeanReducer<type>>);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 TF_CALL_complex64(REGISTER_GPU_KERNELS);
 TF_CALL_complex128(REGISTER_GPU_KERNELS);
@@ -45,17 +58,24 @@ TF_CALL_complex128(REGISTER_GPU_KERNELS);
 #endif
 
 #ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNELS(type)         \
-  REGISTER_KERNEL_BUILDER(                  \
-      Name("Mean")                          \
-          .Device(DEVICE_SYCL)              \
-          .TypeConstraint<type>("T")        \
-          .TypeConstraint<int32>("Tidx")    \
-          .HostMemory("reduction_indices"), \
-      ReductionOp<SYCLDevice, type, Eigen::internal::MeanReducer<type>>);
+#define REGISTER_SYCL_KERNELS(type)                                         \
+  REGISTER_KERNEL_BUILDER(Name("Mean")                                      \
+                              .Device(DEVICE_SYCL)                          \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int32>("Tidx")                \
+                              .HostMemory("reduction_indices"),             \
+                          ReductionOp<SYCLDevice, type, int32,              \
+                                      Eigen::internal::MeanReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(Name("Mean")                                      \
+                              .Device(DEVICE_SYCL)                          \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int64>("Tidx")                \
+                              .HostMemory("reduction_indices"),             \
+                          ReductionOp<SYCLDevice, type, int64,              \
+                                      Eigen::internal::MeanReducer<type>>);
 REGISTER_SYCL_KERNELS(float);
 REGISTER_SYCL_KERNELS(double);
 #undef REGISTER_SYCL_KERNELS
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc
index 1e394bea41f..807ac0a4567 100644
--- a/tensorflow/core/kernels/reduction_ops_min.cc
+++ b/tensorflow/core/kernels/reduction_ops_min.cc
@@ -17,26 +17,39 @@ limitations under the License.
 
 namespace tensorflow {
 
-#define REGISTER_CPU_KERNELS(type)        \
-  REGISTER_KERNEL_BUILDER(                \
-      Name("Min")                         \
-          .Device(DEVICE_CPU)             \
-          .TypeConstraint<type>("T")      \
-          .TypeConstraint<int32>("Tidx"), \
-      ReductionOp<CPUDevice, type, Eigen::internal::MinReducer<type>>);
+#define REGISTER_CPU_KERNELS(type)                                             \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Min")                                                              \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int32>("Tidx"),                                      \
+      ReductionOp<CPUDevice, type, int32, Eigen::internal::MinReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Min")                                                              \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int64>("Tidx"),                                      \
+      ReductionOp<CPUDevice, type, int64, Eigen::internal::MinReducer<type>>);
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
 #if GOOGLE_CUDA
 
-#define REGISTER_GPU_KERNELS(type)          \
-  REGISTER_KERNEL_BUILDER(                  \
-      Name("Min")                           \
-          .Device(DEVICE_GPU)               \
-          .TypeConstraint<type>("T")        \
-          .TypeConstraint<int32>("Tidx")    \
-          .HostMemory("reduction_indices"), \
-      ReductionOp<GPUDevice, type, Eigen::internal::MinReducer<type>>);
+#define REGISTER_GPU_KERNELS(type)                                             \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Min")                                                              \
+          .Device(DEVICE_GPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int32>("Tidx")                                       \
+          .HostMemory("reduction_indices"),                                    \
+      ReductionOp<GPUDevice, type, int32, Eigen::internal::MinReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Min")                                                              \
+          .Device(DEVICE_GPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int64>("Tidx")                                       \
+          .HostMemory("reduction_indices"),                                    \
+      ReductionOp<GPUDevice, type, int64, Eigen::internal::MinReducer<type>>);
 REGISTER_GPU_KERNELS(float);
 REGISTER_GPU_KERNELS(double);
 
@@ -51,21 +64,37 @@ REGISTER_KERNEL_BUILDER(
         .HostMemory("output")
         .TypeConstraint<int32>("T")
         .TypeConstraint<int32>("Tidx"),
-    ReductionOp<CPUDevice, int32, Eigen::internal::MinReducer<int32>>);
+    ReductionOp<CPUDevice, int32, int32, Eigen::internal::MinReducer<int32>>);
+REGISTER_KERNEL_BUILDER(  // int32 "Min" with int64 "Tidx" on DEVICE_GPU.
+    Name("Min")
+        .Device(DEVICE_GPU)
+        .HostMemory("reduction_indices")  // These tensors live in host memory;
+        .HostMemory("input")              // the kernel registered below is the
+        .HostMemory("output")             // CPUDevice instantiation.
+        .TypeConstraint<int32>("T")
+        .TypeConstraint<int64>("Tidx"),
+    ReductionOp<CPUDevice, int32, int64, Eigen::internal::MinReducer<int32>>);
 
 #undef REGISTER_GPU_KERNELS
 
 #endif
 
 #ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNELS(type)         \
-  REGISTER_KERNEL_BUILDER(                  \
-      Name("Min")                           \
-          .Device(DEVICE_SYCL)              \
-          .TypeConstraint<type>("T")        \
-          .TypeConstraint<int32>("Tidx")    \
-          .HostMemory("reduction_indices"), \
-      ReductionOp<SYCLDevice, type, Eigen::internal::MinReducer<type>>);
+#define REGISTER_SYCL_KERNELS(type)                                        \
+  REGISTER_KERNEL_BUILDER(Name("Min")                                      \
+                              .Device(DEVICE_SYCL)                         \
+                              .TypeConstraint<type>("T")                   \
+                              .TypeConstraint<int32>("Tidx")               \
+                              .HostMemory("reduction_indices"),            \
+                          ReductionOp<SYCLDevice, type, int32,             \
+                                      Eigen::internal::MinReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(Name("Min")                                      \
+                              .Device(DEVICE_SYCL)                         \
+                              .TypeConstraint<type>("T")                   \
+                              .TypeConstraint<int64>("Tidx")               \
+                              .HostMemory("reduction_indices"),            \
+                          ReductionOp<SYCLDevice, type, int64,             \
+                                      Eigen::internal::MinReducer<type>>);
 REGISTER_SYCL_KERNELS(float);
 REGISTER_SYCL_KERNELS(double);
 
@@ -77,8 +106,17 @@ REGISTER_KERNEL_BUILDER(
         .HostMemory("output")
         .TypeConstraint<int32>("T")
         .TypeConstraint<int32>("Tidx"),
-    ReductionOp<CPUDevice, int32, Eigen::internal::MinReducer<int32>>);
+    ReductionOp<CPUDevice, int32, int32, Eigen::internal::MinReducer<int32>>);
+REGISTER_KERNEL_BUILDER(  // int32 "Min" with int64 "Tidx" on DEVICE_SYCL.
+    Name("Min")
+        .Device(DEVICE_SYCL)
+        .HostMemory("reduction_indices")  // These tensors live in host memory;
+        .HostMemory("input")              // the kernel registered below is the
+        .HostMemory("output")             // CPUDevice instantiation.
+        .TypeConstraint<int32>("T")
+        .TypeConstraint<int64>("Tidx"),
+    ReductionOp<CPUDevice, int32, int64, Eigen::internal::MinReducer<int32>>);
 #undef REGISTER_SYCL_KERNELS
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_prod.cc b/tensorflow/core/kernels/reduction_ops_prod.cc
index 33f6ae6bae1..e9b23df7460 100644
--- a/tensorflow/core/kernels/reduction_ops_prod.cc
+++ b/tensorflow/core/kernels/reduction_ops_prod.cc
@@ -17,26 +17,39 @@ limitations under the License.
 
 namespace tensorflow {
 
-#define REGISTER_CPU_KERNELS(type)        \
-  REGISTER_KERNEL_BUILDER(                \
-      Name("Prod")                        \
-          .Device(DEVICE_CPU)             \
-          .TypeConstraint<type>("T")      \
-          .TypeConstraint<int32>("Tidx"), \
-      ReductionOp<CPUDevice, type, Eigen::internal::ProdReducer<type>>);
+#define REGISTER_CPU_KERNELS(type)                                          \
+  REGISTER_KERNEL_BUILDER(Name("Prod")                                      \
+                              .Device(DEVICE_CPU)                           \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int32>("Tidx"),               \
+                          ReductionOp<CPUDevice, type, int32,               \
+                                      Eigen::internal::ProdReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(Name("Prod")                                      \
+                              .Device(DEVICE_CPU)                           \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int64>("Tidx"),               \
+                          ReductionOp<CPUDevice, type, int64,               \
+                                      Eigen::internal::ProdReducer<type>>);
 TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
 #if GOOGLE_CUDA
 
-#define REGISTER_GPU_KERNELS(type)          \
-  REGISTER_KERNEL_BUILDER(                  \
-      Name("Prod")                          \
-          .Device(DEVICE_GPU)               \
-          .TypeConstraint<type>("T")        \
-          .TypeConstraint<int32>("Tidx")    \
-          .HostMemory("reduction_indices"), \
-      ReductionOp<GPUDevice, type, Eigen::internal::ProdReducer<type>>);
+#define REGISTER_GPU_KERNELS(type)                                          \
+  REGISTER_KERNEL_BUILDER(Name("Prod")                                      \
+                              .Device(DEVICE_GPU)                           \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int32>("Tidx")                \
+                              .HostMemory("reduction_indices"),             \
+                          ReductionOp<GPUDevice, type, int32,               \
+                                      Eigen::internal::ProdReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(Name("Prod")                                      \
+                              .Device(DEVICE_GPU)                           \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int64>("Tidx")                \
+                              .HostMemory("reduction_indices"),             \
+                          ReductionOp<GPUDevice, type, int64,               \
+                                      Eigen::internal::ProdReducer<type>>);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 TF_CALL_int32(REGISTER_GPU_KERNELS);
 TF_CALL_complex64(REGISTER_GPU_KERNELS);
@@ -46,18 +59,25 @@ TF_CALL_complex128(REGISTER_GPU_KERNELS);
 #endif
 
 #ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNELS(type)         \
-  REGISTER_KERNEL_BUILDER(                  \
-      Name("Prod")                          \
-          .Device(DEVICE_SYCL)              \
-          .TypeConstraint<type>("T")        \
-          .TypeConstraint<int32>("Tidx")    \
-          .HostMemory("reduction_indices"), \
-      ReductionOp<SYCLDevice, type, Eigen::internal::ProdReducer<type>>);
+#define REGISTER_SYCL_KERNELS(type)                                         \
+  REGISTER_KERNEL_BUILDER(Name("Prod")                                      \
+                              .Device(DEVICE_SYCL)                          \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int32>("Tidx")                \
+                              .HostMemory("reduction_indices"),             \
+                          ReductionOp<SYCLDevice, type, int32,              \
+                                      Eigen::internal::ProdReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(Name("Prod")                                      \
+                              .Device(DEVICE_SYCL)                          \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int64>("Tidx")                \
+                              .HostMemory("reduction_indices"),             \
+                          ReductionOp<SYCLDevice, type, int64,              \
+                                      Eigen::internal::ProdReducer<type>>);
 REGISTER_SYCL_KERNELS(int32);
 REGISTER_SYCL_KERNELS(float);
 REGISTER_SYCL_KERNELS(double);
 #undef REGISTER_SYCL_KERNELS
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index c1f4f3475af..5318d8c1339 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -17,26 +17,39 @@ limitations under the License.
 
 namespace tensorflow {
 
-#define REGISTER_CPU_KERNELS(type)        \
-  REGISTER_KERNEL_BUILDER(                \
-      Name("Sum")                         \
-          .Device(DEVICE_CPU)             \
-          .TypeConstraint<type>("T")      \
-          .TypeConstraint<int32>("Tidx"), \
-      ReductionOp<CPUDevice, type, Eigen::internal::SumReducer<type>>);
+#define REGISTER_CPU_KERNELS(type) /* CPU "Sum": int32 + int64 "Tidx" */       \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Sum")                                                              \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int32>("Tidx"),                                      \
+      ReductionOp<CPUDevice, type, int32, Eigen::internal::SumReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Sum")                                                              \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int64>("Tidx"),                                      \
+      ReductionOp<CPUDevice, type, int64, Eigen::internal::SumReducer<type>>);
 TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
 #if GOOGLE_CUDA
 
-#define REGISTER_GPU_KERNELS(type)          \
-  REGISTER_KERNEL_BUILDER(                  \
-      Name("Sum")                           \
-          .Device(DEVICE_GPU)               \
-          .TypeConstraint<type>("T")        \
-          .TypeConstraint<int32>("Tidx")    \
-          .HostMemory("reduction_indices"), \
-      ReductionOp<GPUDevice, type, Eigen::internal::SumReducer<type>>);
+#define REGISTER_GPU_KERNELS(type) /* GPU "Sum": int32 + int64 "Tidx" */       \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Sum")                                                              \
+          .Device(DEVICE_GPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int32>("Tidx")                                       \
+          .HostMemory("reduction_indices"), /* argument kept in host memory */ \
+      ReductionOp<GPUDevice, type, int32, Eigen::internal::SumReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("Sum")                                                              \
+          .Device(DEVICE_GPU)                                                  \
+          .TypeConstraint<type>("T")                                           \
+          .TypeConstraint<int64>("Tidx")                                       \
+          .HostMemory("reduction_indices"), /* argument kept in host memory */ \
+      ReductionOp<GPUDevice, type, int64, Eigen::internal::SumReducer<type>>);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
 TF_CALL_complex64(REGISTER_GPU_KERNELS);
 TF_CALL_complex128(REGISTER_GPU_KERNELS);
@@ -53,19 +66,35 @@ REGISTER_KERNEL_BUILDER(
         .HostMemory("input")
         .HostMemory("output")
         .HostMemory("reduction_indices"),
-    ReductionOp<CPUDevice, int32, Eigen::internal::SumReducer<int32>>);
+    ReductionOp<CPUDevice, int32, int32, Eigen::internal::SumReducer<int32>>);
+REGISTER_KERNEL_BUILDER(  // int32 "Sum" on DEVICE_GPU with int64 "Tidx".
+    Name("Sum")
+        .Device(DEVICE_GPU)
+        .TypeConstraint<int32>("T")
+        .TypeConstraint<int64>("Tidx")
+        .HostMemory("input")   // Input, output and indices are all declared
+        .HostMemory("output")  // as host memory, and the op registered below
+        .HostMemory("reduction_indices"),  // is the CPUDevice instantiation.
+    ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
 
 #endif
 
 #ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNELS(type)         \
-  REGISTER_KERNEL_BUILDER(                  \
-      Name("Sum")                           \
-          .Device(DEVICE_SYCL)              \
-          .TypeConstraint<type>("T")        \
-          .TypeConstraint<int32>("Tidx")    \
-          .HostMemory("reduction_indices"), \
-      ReductionOp<SYCLDevice, type, Eigen::internal::SumReducer<type>>);
+#define REGISTER_SYCL_KERNELS(type) /* SYCL "Sum": int32 + int64 "Tidx" */ \
+  REGISTER_KERNEL_BUILDER(Name("Sum")                                      \
+                              .Device(DEVICE_SYCL)                         \
+                              .TypeConstraint<type>("T")                   \
+                              .TypeConstraint<int32>("Tidx")               \
+                              .HostMemory("reduction_indices"),            \
+                          ReductionOp<SYCLDevice, type, int32,             \
+                                      Eigen::internal::SumReducer<type>>); \
+  REGISTER_KERNEL_BUILDER(Name("Sum")                                      \
+                              .Device(DEVICE_SYCL)                         \
+                              .TypeConstraint<type>("T")                   \
+                              .TypeConstraint<int64>("Tidx")               \
+                              .HostMemory("reduction_indices"),            \
+                          ReductionOp<SYCLDevice, type, int64,             \
+                                      Eigen::internal::SumReducer<type>>);
 REGISTER_SYCL_KERNELS(float);
 REGISTER_SYCL_KERNELS(double);
 
@@ -77,8 +106,17 @@ REGISTER_KERNEL_BUILDER(
         .HostMemory("input")
         .HostMemory("output")
         .HostMemory("reduction_indices"),
-    ReductionOp<CPUDevice, int32, Eigen::internal::SumReducer<int32>>);
+    ReductionOp<CPUDevice, int32, int32, Eigen::internal::SumReducer<int32>>);
+REGISTER_KERNEL_BUILDER(  // int32 "Sum" on DEVICE_SYCL with int64 "Tidx".
+    Name("Sum")
+        .Device(DEVICE_SYCL)
+        .TypeConstraint<int32>("T")
+        .TypeConstraint<int64>("Tidx")
+        .HostMemory("input")   // Input, output and indices are all declared
+        .HostMemory("output")  // as host memory, and the op registered below
+        .HostMemory("reduction_indices"),  // is the CPUDevice instantiation.
+    ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
 #undef REGISTER_SYCL_KERNELS
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index c794351fe99..2dc65b13849 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -163,6 +163,14 @@ class SumReductionTest(BaseReductionTest):
       reduction_axes = tuple(reduction_axes)
     return np.sum(x, axis=reduction_axes, keepdims=keep_dims)
 
+  def testAxesType(self):
+    """Checks reduce_sum with int64 and int32 reduction axes."""
+    for axes_dtype in [dtypes.int64, dtypes.int32]:
+      with self.test_session(use_gpu=True) as sess:
+        axes = constant_op.constant(0, dtype=axes_dtype)
+        result = sess.run(math_ops.reduce_sum([0, 0], axes))
+      self.assertAllEqual(result, 0)
+
   def testInfinity(self):
     for dtype in [np.float32, np.float64]:
       for special_value_x in [-np.inf, np.inf]:
@@ -369,6 +377,13 @@ class MeanReductionTest(BaseReductionTest):
       return np_sum // count
     return np_sum / count
 
+  def testAxesType(self):
+    """Checks reduce_mean with int64 and int32 reduction axes."""
+    for axes_dtype in [dtypes.int64, dtypes.int32]:
+      with self.test_session(use_gpu=True) as sess:
+        axes = constant_op.constant(0, dtype=axes_dtype)
+        self.assertAllEqual(sess.run(math_ops.reduce_mean([0, 0], axes)), 0)
+
   def testInfinity(self):
     for dtype in [np.float32, np.float64]:
       for special_value_x in [-np.inf, np.inf]:
@@ -435,6 +450,13 @@ class ProdReductionTest(BaseReductionTest):
       reduction_axes = tuple(reduction_axes)
     return np.prod(x, axis=reduction_axes, keepdims=keep_dims)
 
+  def testAxesType(self):
+    """Checks reduce_prod with int64 and int32 reduction axes."""
+    for axes_dtype in [dtypes.int64, dtypes.int32]:
+      with self.test_session(use_gpu=True) as sess:
+        axes = constant_op.constant(0, dtype=axes_dtype)
+        self.assertAllEqual(sess.run(math_ops.reduce_prod([0, 0], axes)), 0)
+
   def testInfinity(self):
     for dtype in [np.float32, np.float64]:
       for special_value_x in [-np.inf, np.inf]:
@@ -531,6 +553,13 @@ class MinReductionTest(test.TestCase):
     self._compare(x, reduction_axes, True, use_gpu=True)
     self._compare(x, reduction_axes, True, use_gpu=False)
 
+  def testAxesType(self):
+    """Checks reduce_min with int64 and int32 reduction axes."""
+    for axes_dtype in [dtypes.int64, dtypes.int32]:
+      with self.test_session(use_gpu=True) as sess:
+        axes = constant_op.constant(0, dtype=axes_dtype)
+        self.assertAllEqual(sess.run(math_ops.reduce_min([0, 0], axes)), 0)
+
   def testInfinity(self):
     for dtype in [np.float32, np.float64]:
       for special_value_x in [-np.inf, np.inf]:
@@ -637,6 +666,13 @@ class MaxReductionTest(test.TestCase):
     self._compare(x, reduction_axes, True, use_gpu=True)
     self._compare(x, reduction_axes, True, use_gpu=False)
 
+  def testAxesType(self):
+    """Checks reduce_max with int64 and int32 reduction axes."""
+    for axes_dtype in [dtypes.int64, dtypes.int32]:
+      with self.test_session(use_gpu=True) as sess:
+        axes = constant_op.constant(0, dtype=axes_dtype)
+        self.assertAllEqual(sess.run(math_ops.reduce_max([0, 0], axes)), 0)
+
   def testInfinity(self):
     for dtype in [np.float32, np.float64]:
       for special_value_x in [-np.inf, np.inf]:
@@ -757,6 +793,14 @@ class AllReductionTest(test.TestCase):
     self._compare(x, reduction_axes, True, use_gpu=True)
     self._compare(x, reduction_axes, True, use_gpu=False)
 
+  def testAxesType(self):
+    """Checks reduce_all with int64 and int32 reduction axes."""
+    for axes_dtype in [dtypes.int64, dtypes.int32]:
+      with self.test_session(use_gpu=True) as sess:
+        axes = constant_op.constant(0, dtype=axes_dtype)
+        result = sess.run(math_ops.reduce_all([True, True], axes))
+      self.assertAllEqual(result, True)
+
   def testAll3D(self):
     # Create a 3D array of bools and reduce across all possible
     # dimensions
@@ -798,6 +842,14 @@ class AnyReductionTest(test.TestCase):
     self._compare(x, reduction_axes, True, use_gpu=True)
     self._compare(x, reduction_axes, True, use_gpu=False)
 
+  def testAxesType(self):
+    """Checks reduce_any with int64 and int32 reduction axes."""
+    for axes_dtype in [dtypes.int64, dtypes.int32]:
+      with self.test_session(use_gpu=True) as sess:
+        axes = constant_op.constant(0, dtype=axes_dtype)
+        result = sess.run(math_ops.reduce_any([True, True], axes))
+      self.assertAllEqual(result, True)
+
   def testAll3D(self):
     # Create a 3D array of bools and reduce across all possible
     # dimensions