diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 9e2ef964a1f..7ff01be3cb4 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -88,6 +88,38 @@ class XlaSortOpTest(xla_test.XLATestCase):
               topk, [x.astype(dtype)],
               expected=[x[indices].astype(dtype), indices])
 
+  def testTopK2D(self):
+    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+    if self.device in ["XLA_CPU", "XLA_GPU"]:
+      return
+
+    supported_types = set(
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+    for dtype in supported_types.intersection(self.numeric_types):
+      # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+      # after conversion to bfloat16, so the possible resulting index array is
+      # no longer unique.
+      if dtype == dtypes.bfloat16.as_numpy_dtype:
+        array_size = 10
+        k_options = [0, 1, 2, 10]
+      else:
+        array_size = 200 * 1000
+        k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+      batch = 16
+      for x in [np.arange(batch * array_size)]:
+        np.random.shuffle(x)
+        x = np.reshape(x, [batch, array_size])
+        for k in k_options:
+          indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
+          expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]
+
+          def topk(v, k=k):
+            return nn_ops.top_k(v, k=k, sorted=True)
+
+          self._assertOpOutputMatchesExpected(
+              topk, [x.astype(dtype)],
+              expected=[expected.astype(dtype), indices])
+
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
     # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index 1ddcb08c8e1..82d4a69777b 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -41,33 +41,35 @@ class TopKOp : public XlaOpKernel {
     OP_REQUIRES(context, input_shape.dims() >= 1,
                 errors::InvalidArgument("input must be >= 1-D, got shape ",
                                         input_shape.DebugString()));
+    int last_dim = input_shape.dims() - 1;
+    int last_dim_size = input_shape.dim_size(last_dim);
     OP_REQUIRES(
-        context, input_shape.dim_size(input_shape.dims() - 1) >= k,
+        context, last_dim_size >= k,
         errors::InvalidArgument("input must have at least k columns. Had ",
-                                input_shape.dim_size(input_shape.dims() - 1),
-                                ", needed ", k));
-
-    OP_REQUIRES(
-        context, input_shape.dims() == 1,
-        errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ",
-                              input_shape.DebugString()));
+                                last_dim_size, ", needed ", k));
 
     xla::XlaBuilder* const b = context->builder();
-    if (input_shape.dim_size(0) < k) {
-      k = input_shape.dim_size(0);
+    if (last_dim_size < k) {
+      k = last_dim_size;
     }
     const xla::XlaOp input = context->Input(0);
-    xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0));
-    xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32);
+
+    xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size);
+    auto input_dims = input_shape.dim_sizes();
+    std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1);
+    xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims);
+    xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32);
+
+    std::vector<int64> start_indices(input_shape.dims(), 0);
+    std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
+    limit_indices[last_dim] = k;
+    std::vector<int64> strides(input_shape.dims(), 1);
+
     xla::XlaOp values =
-        xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0),
-                            /*start_indices=*/{0},
-                            /*limit_indices=*/{k},
-                            /*strides=*/{1}));
+        xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices,
+                            limit_indices, strides));
     xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
-                                    /*start_indices=*/{0},
-                                    /*limit_indices=*/{k},
-                                    /*strides=*/{1});
+                                    start_indices, limit_indices, strides);
     context->SetOutput(0, values);
     context->SetOutput(1, indices);
   }