diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 6e76db02de0..62ae5ae78c6 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -208,7 +208,6 @@ filegroup(
         "//tensorflow/compiler/jit/kernels:all_files",
         "//tensorflow/compiler/jit/legacy_flags:all_files",
         "//tensorflow/compiler/jit/ops:all_files",
-        "//tensorflow/compiler/plugin/executor:all_files",
         "//tensorflow/compiler/tests:all_files",
         "//tensorflow/compiler/tf2xla:all_files",
         "//tensorflow/compiler/tf2xla/cc:all_files",
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 7ebd8422181..306e704415b 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -15,10 +15,7 @@ package_group(
 )
 
 package(
-    default_visibility = [
-        ":internal",
-        "//tensorflow/compiler/plugin/executor:__pkg__",
-    ],
+    default_visibility = [":internal"],
 )
 
 load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 97f3512a6c4..ed204b81821 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -2,7 +2,6 @@ licenses(["notice"])  # Apache 2.0
 
 package(
     default_visibility = [
-        "//tensorflow/compiler/plugin/executor:__pkg__",
         "//tensorflow/compiler/tf2xla:internal",
     ],
 )
diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD
index 2e5875705f2..9bc706abdf6 100644
--- a/tensorflow/compiler/plugin/executor/BUILD
+++ b/tensorflow/compiler/plugin/executor/BUILD
@@ -11,11 +11,9 @@ cc_library(
         "*.h",
     ]),
     deps = [
-        "//tensorflow/compiler/jit:xla_device",
         "//tensorflow/compiler/jit:xla_jit_headers_lib",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:xla_headers_lib",
-        "//tensorflow/compiler/xla/service",
+        "//tensorflow/compiler/xla/service:hlo_evaluator",
         "//third_party/eigen3",
         "@local_config_cuda//cuda:cuda_headers",
         "@protobuf//:protobuf_headers",
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 432b24756d2..044857d4222 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -175,11 +175,6 @@ tf_xla_py_test(
     name = "slice_ops_test",
     size = "small",
     srcs = ["slice_ops_test.py"],
-    # TODO(b/62962492): Test fails with assertion error.
-    tags = [
-        "manual",
-        "notap",
-    ],
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
@@ -461,11 +456,6 @@ cuda_py_test(
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nn_ops",
     ],
-    # TODO(b/62961789): Test fails with SIGABRT
-    tags = [
-        "manual",
-        "notap",
-    ],
 )
 
 cc_library(
@@ -534,12 +524,8 @@ cuda_py_test(
 # --dump_graph_dir, and the config file was written by hand.
 #
 # Run the following to build a minimal benchmark of the computation on Android:
-# $ bazel build -c opt --cxxopt='-std=c++11' --linkopt='-lm' \
-#   --cpu=armeabi-v7a \
-#   --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
-#   --crosstool_top=//external:android/crosstool \
-#   //tensorflow/compiler/tests:lstm_layer_inference_benchmark
-
+# $ bazel build -c opt --config=android_arm \
+#       //tensorflow/compiler/tests:lstm_layer_inference_benchmark
 #
 # Currently the resulting binary size is ~190KB
 tf_library(
diff --git a/tensorflow/core/kernels/sparse_reduce_sum_op.cc b/tensorflow/core/kernels/sparse_reduce_sum_op.cc
new file mode 100644
index 00000000000..074aab9f9e2
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_reduce_sum_op.cc
@@ -0,0 +1,318 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/sparse_ops.cc.
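+//
+// SparseReduceSum sums the values of a SparseTensor along the dimensions in
+// reduction_axes and returns a dense Tensor; SparseReduceSumSparse returns a
+// SparseTensor instead. For example, summing the 2x2 sparse matrix
+// [[1, _], [_, 3]] along axis 1 yields the dense vector [1, 3].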
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+// TODO(b/31496047): Fix non-standard include order.
+#include <numeric>  // clang-format off
+
+using tensorflow::sparse::SparseTensor;
+using tensorflow::gtl::ArraySlice;
+
+namespace tensorflow {
+
+struct ReduceDetails {
+  // The dimensions to call Reorder() with.
+  std::vector<int64> reorder_dims;
+
+  // The dimensions to call group() with after Reorder().
+  std::vector<int64> group_by_dims;
+
+  // The shape after reduction.
+  TensorShape reduced_shape;
+};
+
+// Compute common reduce parameters that'll be used for SparseTensor
+// reductions. Usage:
+// ReduceDetails reduction = SparseTensorReduceHelper(sp, axes, keep_dims);
+// sp.Reorder(reduction.reorder_dims);
+// for (const auto& g : sp.group(reduction.group_by_dims)) {
+//   ...
+// }
+// // Set output shape to reduction.reduced_shape.
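+//
+// For example, for a rank-3 input with reduction_axes = {1}:
+//   group_by_dims = {0, 2}     (the dims kept in the output)
+//   reorder_dims  = {0, 2, 1}  (kept dims first, reduced dims last)
+//   reduced_shape = {d0, d2}, or {d0, 1, d2} when keep_dims is true.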
+ReduceDetails SparseTensorReduceHelper(const SparseTensor &sp,
+                                       gtl::ArraySlice<int32> axes_slice,
+                                       bool keep_dims) {
+  ReduceDetails reduction;
+
+  std::vector<int32> reduction_axes(axes_slice.begin(), axes_slice.end());
+  int ndims = sp.dims();
+  for (int64 i = 0; i < reduction_axes.size(); ++i) {
+    reduction_axes[i] = (reduction_axes[i] + ndims) % ndims;
+  }
+  std::sort(reduction_axes.begin(), reduction_axes.end());
+
+  // (0) Calculate the grouping dimensions:
+  // group_by_dims == {0, .., NDIMS-1} \ reduction_axes.
+  std::vector<int64> perm(ndims);
+  std::iota(perm.begin(), perm.end(), 0);
+
+  // Requires perm and reduction_axes to be sorted; group_by_dims will be
+  // sorted as well.
+  std::set_difference(
+      perm.begin(), perm.end(), reduction_axes.begin(), reduction_axes.end(),
+      std::inserter(reduction.group_by_dims, reduction.group_by_dims.begin()));
+
+  // Now append the rest of the axes (the complement of group_by_dims); the
+  // result is used by Reorder().
+  reduction.reorder_dims = reduction.group_by_dims;
+  std::set_difference(perm.begin(), perm.end(), reduction.group_by_dims.begin(),
+                      reduction.group_by_dims.end(),
+                      std::back_inserter(reduction.reorder_dims));
+
+  // (1) Calculate the shape after reduction.
+  auto sp_shape = sp.shape();
+  std::vector<int64> out_dim_sizes;
+  if (keep_dims) {
+    out_dim_sizes.reserve(ndims);
+    auto beg = reduction.group_by_dims.begin();
+    auto end = reduction.group_by_dims.end();
+    for (int d = 0; d < ndims; ++d) {
+      if (std::find(beg, end, d) == end) {
+        out_dim_sizes.push_back(1);  // A reduced axis.
+      } else {
+        out_dim_sizes.push_back(sp_shape[d]);
+      }
+    }
+  } else {
+    out_dim_sizes = sp.PickDims(reduction.group_by_dims);
+  }
+
+  reduction.reduced_shape = TensorShape(out_dim_sizes);
+  return reduction;
+}
+
+Status ValidateInputs(const Tensor *shape_t, const Tensor *reduction_axes_t) {
+  // indices and values are validated in SparseTensor ctor.
+  if (!TensorShapeUtils::IsVector(shape_t->shape())) {
+    return errors::InvalidArgument(
+        "Expected input_shape to be a vector; got shape: ",
+        shape_t->shape().DebugString());
+  }
+  if (!TensorShapeUtils::IsScalar(reduction_axes_t->shape()) &&
+      !TensorShapeUtils::IsVector(reduction_axes_t->shape())) {
+    return errors::InvalidArgument(
+        "Expected reduction_axes to be a scalar or a vector; got shape: ",
+        reduction_axes_t->shape().DebugString());
+  }
+
+  const auto reduction_axes_flat = reduction_axes_t->flat<int32>();
+  for (int64 i = 0; i < reduction_axes_flat.size(); i++) {
+    int32 axis = reduction_axes_flat(i);
+    if (axis < -shape_t->NumElements() || axis >= shape_t->NumElements()) {
+      return errors::InvalidArgument("Invalid reduction dimension ", axis,
+                                     ", for input with ",
+                                     shape_t->NumElements(), " dimensions.");
+    }
+  }
+
+  return Status::OK();
+}
+
+template <typename T>
+class SparseReduceSumOp : public OpKernel {
+ public:
+  explicit SparseReduceSumOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+  }
+
+  void Compute(OpKernelContext *ctx) override {
+    const Tensor *indices_t, *values_t, *shape_t, *reduction_axes_t;
+    OP_REQUIRES_OK(ctx, ctx->input("input_indices", &indices_t));
+    OP_REQUIRES_OK(ctx, ctx->input("input_values", &values_t));
+    OP_REQUIRES_OK(ctx, ctx->input("input_shape", &shape_t));
+    OP_REQUIRES_OK(ctx, ctx->input("reduction_axes", &reduction_axes_t));
+
+    OP_REQUIRES_OK(ctx, ValidateInputs(shape_t, reduction_axes_t));
+
+    // TODO(zongheng): we will call Reorder() below, which will modify
+    // in-place the underlying indices and values buffers.  To avoid
+    // surprises of this kernel being stateful, we work around the above by
+    // making deep copies here.  Remove this if/when we change Reorder()'s
+    // semantics.
+    const auto shape_vec = shape_t->vec<int64>();
+    SparseTensor sp(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+                    TensorShape(shape_vec));
+    ReduceDetails reduction = SparseTensorReduceHelper(
+        sp, reduction_axes_t->flat<int32>(), keep_dims_);
+
+    Tensor *out_values;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output(0, reduction.reduced_shape, &out_values));
+    auto out_flat = out_values->flat<T>();
+    out_flat.setZero();
+
+    Tensor tmp_group_sum;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+                                           TensorShape({}), &tmp_group_sum));
+    auto group_sum = tmp_group_sum.scalar<T>();
+
+    // Compute strides, and use it to convert coords to flat index.  The
+    // coordinates returned by .group() have the same ndims as group_by_dims.
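+    // For example, with group_by_dims = {0, 2} and dense shape (2, 3, 4),
+    // output_strides = {4, 1}, so group coordinates (i, k) map to flat
+    // output index i * 4 + k.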
+    gtl::InlinedVector<int64, 8> output_strides(reduction.group_by_dims.size());
+    if (!output_strides.empty()) {  // Do this iff we don't reduce all.
+      output_strides.back() = 1;
+      for (int d = output_strides.size() - 2; d >= 0; --d) {
+        output_strides[d] =
+            output_strides[d + 1] * shape_vec(reduction.group_by_dims[d + 1]);
+      }
+    }
+
+    auto CoordinatesToFlatIndex = [](ArraySlice<int64> coords,
+                                     ArraySlice<int64> strides) {
+      if (strides.empty()) {  // Reduce all.
+        return 0LL;
+      }
+      CHECK_EQ(coords.size(), strides.size());
+      int64 idx = 0;
+      for (int i = 0; i < coords.size(); ++i) {
+        idx += coords[i] * strides[i];
+      }
+      return idx;
+    };
+
+    // Each group maps one-to-one onto a value in the reduced tensor.
+    // g.group() provides the coordinates of a particular reduced value.
+    sp.Reorder<T>(reduction.reorder_dims);
+    for (const auto &g : sp.group(reduction.group_by_dims)) {
+      group_sum.device(ctx->eigen_cpu_device()) = g.template values<T>().sum();
+      const int64 idx = CoordinatesToFlatIndex(g.group(), output_strides);
+      out_flat(idx) = group_sum();
+      VLOG(2) << "coords: " << str_util::Join(g.group(), ",")
+              << "; idx: " << idx << "; group sum: " << group_sum();
+    }
+  }
+
+ private:
+  // True if the number of dimensions should be maintained.
+  bool keep_dims_;
+};
+
+#define REGISTER_KERNELS(T)                                              \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("SparseReduceSum").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      SparseReduceSumOp<T>)
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+template <typename T>
+class SparseReduceSumSparseOp : public OpKernel {
+ public:
+  explicit SparseReduceSumSparseOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+  }
+
+  void Compute(OpKernelContext *ctx) override {
+    const Tensor *indices_t, *values_t, *shape_t, *reduction_axes_t;
+    OP_REQUIRES_OK(ctx, ctx->input("input_indices", &indices_t));
+    OP_REQUIRES_OK(ctx, ctx->input("input_values", &values_t));
+    OP_REQUIRES_OK(ctx, ctx->input("input_shape", &shape_t));
+    OP_REQUIRES_OK(ctx, ctx->input("reduction_axes", &reduction_axes_t));
+
+    OP_REQUIRES_OK(ctx, ValidateInputs(shape_t, reduction_axes_t));
+
+    SparseTensor sp(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+                    TensorShape(shape_t->vec<int64>()));
+    ReduceDetails reduction = SparseTensorReduceHelper(
+        sp, reduction_axes_t->flat<int32>(), keep_dims_);
+
+    sp.Reorder<T>(reduction.reorder_dims);
+    // Count nnzs in the output SparseTensor.
+    int64 nnz = 0;
+    auto iter = sp.group(reduction.group_by_dims);
+    for (auto it = iter.begin(); it != iter.end(); ++it) {
+      nnz++;
+    }
+
+    Tensor *out_indices_t;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(
+                       0, TensorShape({nnz, reduction.reduced_shape.dims()}),
+                       &out_indices_t));
+    typename TTypes<int64>::Matrix out_indices_mat =
+        out_indices_t->matrix<int64>();
+    // With keep_dims, reduced dims keep the 0 written by setZero() below.
+    out_indices_mat.setZero();
+
+    Tensor *out_values_t;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(1, TensorShape({nnz}), &out_values_t));
+    auto out_flat = out_values_t->flat<T>();
+
+    Tensor tmp_group_sum;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+                                           TensorShape({}), &tmp_group_sum));
+    auto group_sum = tmp_group_sum.scalar<T>();
+    int64 i = 0;
+    for (const auto &g : sp.group(reduction.group_by_dims)) {
+      group_sum.device(ctx->eigen_cpu_device()) = g.template values<T>().sum();
+      std::vector<int64> group = g.group();
+      for (int64 j = 0; j < group.size(); j++) {
+        if (keep_dims_) {
+          out_indices_mat(i, reduction.group_by_dims[j]) = group[j];
+        } else {
+          out_indices_mat(i, j) = group[j];
+        }
+      }
+      out_flat(i) = group_sum();
+      i++;
+      VLOG(2) << "coords: " << str_util::Join(g.group(), ",")
+              << "; group sum: " << group_sum();
+    }
+
+    Tensor *out_shape_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(
+                            2, TensorShape({reduction.reduced_shape.dims()}),
+                            &out_shape_t));
+    auto out_shape_flat = out_shape_t->flat<int64>();
+    auto out_dim_sizes = reduction.reduced_shape.dim_sizes();
+    std::copy(out_dim_sizes.begin(), out_dim_sizes.end(), &out_shape_flat(0));
+  }
+
+ private:
+  // True if the number of dimensions should be maintained.
+  bool keep_dims_;
+};
+
+#define REGISTER_KERNELS(T)                                                    \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("SparseReduceSumSparse").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      SparseReduceSumSparseOp<T>)
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+}  // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 936348e01d2..22b18b9cde0 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3056,7 +3056,6 @@ py_test(
     srcs = ["client/session_clusterspec_prop_test.py"],
     srcs_version = "PY2AND3",
     tags = [
-        "no_gpu",
         "no_pip_gpu",
     ],
     deps = [
@@ -3081,7 +3080,6 @@ py_test(
     srcs = ["client/session_list_devices_test.py"],
     srcs_version = "PY2AND3",
     tags = [
-        "no_gpu",
         "no_pip_gpu",
     ],
     deps = [
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index 66643622260..14eb2cba68a 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
+import unittest
 import numpy as np
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -605,6 +606,7 @@ class SparseReduceTest(test_util.TensorFlowTestCase):
     self._compare(sp_t, reduction_axes, ndims, True, False)
     self._compare(sp_t, reduction_axes, ndims, True, True)
 
+  @unittest.skipIf(np.__version__ == "1.13.0", "numpy 1.13.0 bug")
   def testSimpleAndRandomInputs(self):
     if np.__version__ == "1.13.0":
       self.skipTest("numpy 1.13.0 bug")
@@ -644,6 +646,7 @@ class SparseReduceTest(test_util.TensorFlowTestCase):
       with self.assertRaisesOpError("Invalid reduction dimension 2"):
         sparse_ops.sparse_reduce_max(sp_t, 2).eval()
 
+  @unittest.skipIf(np.__version__ == "1.13.0", "numpy 1.13.0 bug")
   def testGradient(self):
     if np.__version__ == "1.13.0":
       self.skipTest("numpy 1.13.0 bug")