diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index d50fdec7c63..838718aa1e3 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -129,42 +129,35 @@ class XlaSortOpTest(xla_test.XLATestCase):
 
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
-    # Only bfloat16 is implemented.
-    bfloat16 = dtypes.bfloat16.as_numpy_dtype
-    if bfloat16 not in self.numeric_types:
-      return
-
-    with self.session() as sess:
-      p = array_ops.placeholder(dtypes.bfloat16)
-      with self.test_scope():
-        topk = nn_ops.top_k(p, k=4)
-      results = sess.run(
-          topk,
-          {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
-      self.assertAllEqual(
-          np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
-      self.assertEqual(list([3, 0, 2, 6]), list(results[1]))
+    supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+    for dtype in supported_types.intersection(self.numeric_types):
+      with self.session() as sess:
+        p = array_ops.placeholder(dtype)
+        with self.test_scope():
+          topk = nn_ops.top_k(p, k=4)
+        results = sess.run(
+            topk,
+            {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=dtype)})
+        self.assertAllEqual(np.array([3., 0., 0., 0.], dtype=dtype), results[0])
+        self.assertEqual(list([3, 0, 2, 6]), list(results[1]))
 
   def testTopKInfinities(self):
     """Tests that positive and negative infinity sort correctly."""
-    # Only bfloat16 is implemented.
-    bfloat16 = dtypes.bfloat16.as_numpy_dtype
-    if bfloat16 not in self.numeric_types:
-      return
-
-    with self.session() as sess:
-      p = array_ops.placeholder(dtypes.bfloat16)
-      with self.test_scope():
-        topk = nn_ops.top_k(p, k=6)
-      results = sess.run(topk, {
-          p: np.array(
-              [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16)
-      })
-      self.assertAllEqual(
-          np.array(
-              [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
-              dtype=bfloat16), results[0])
-      self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
+    supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+    for dtype in supported_types.intersection(self.numeric_types):
+      with self.session() as sess:
+        p = array_ops.placeholder(dtype)
+        with self.test_scope():
+          topk = nn_ops.top_k(p, k=6)
+        results = sess.run(topk, {
+            p:
+                np.array([1, 2, float("inf"), -float("inf"), -1, -2],
+                         dtype=dtype)
+        })
+        self.assertAllEqual(
+            np.array([float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
+                     dtype=dtype), results[0])
+        self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
 
   def testInTopK(self):
     supported_types = set([np.int32, np.int64])
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 782d08296f0..6eaf43902fe 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -49,6 +49,7 @@ filegroup(
         "runtime_single_threaded_conv2d.cc",
         "runtime_single_threaded_fft.cc",
         "runtime_single_threaded_matmul.cc",
+        "runtime_topk.cc",
     ],
     visibility = [":friends"],
 )
@@ -64,6 +65,7 @@ filegroup(
         "runtime_single_threaded_conv2d.h",
         "runtime_single_threaded_fft.h",
         "runtime_single_threaded_matmul.h",
+        "runtime_topk.h",
     ],
     visibility = [":friends"],
 )
@@ -134,6 +136,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:copy_insertion",
         "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:dump",
+        "//tensorflow/compiler/xla/service:topk_rewriter",
         "//tensorflow/compiler/xla/service:map_inliner",
         "//tensorflow/compiler/xla/service:rng_bit_generator_expander",
         "//tensorflow/compiler/xla/service:tree_reduction_rewriter",
@@ -230,6 +233,7 @@ cc_library(
         ":runtime_fft",
         ":runtime_fork_join",
         ":runtime_key_value_sort",
+        ":runtime_topk",
         ":runtime_matmul",
         ":runtime_matmul_mkl",
         ":runtime_single_threaded_conv2d",
@@ -759,6 +763,19 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "runtime_topk",
+    srcs = ["runtime_topk.cc"],
+    hdrs = ["runtime_topk.h"],
+    copts = runtime_copts(),
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core/platform:dynamic_annotations",
+        "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:types",
+    ],
+)
+
 cc_library(
     name = "runtime_fork_join",
     srcs = ["runtime_fork_join.cc"],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 04d703fdd59..0826d7b8ce1 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -104,6 +104,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/slice_sinker.h"
 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
+#include "tensorflow/compiler/xla/service/topk_rewriter.h"
 #include "tensorflow/compiler/xla/service/transpose_folding.h"
 #include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h"
 #include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
@@ -320,6 +321,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
     pass.AddPass<HloConstantFolding>();
     pass.AddPass<ConditionalSimplifier>();
   }
+  pipeline.AddPass<TopkRewriter>([](const HloSortInstruction* sort, int64) {
+    return sort->operand(0)->shape().element_type() == F32;
+  });
   pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
   pipeline.AddPass<TransposeFolding>(
       [&](const HloInstruction& dot,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 2231ecfa1e8..5bee6049a5e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -117,6 +117,7 @@ extern const char* const kParallelForkJoinSymbolName =
     "__xla_cpu_runtime_ParallelForkJoin";
 extern const char* const kKeyValueSortSymbolName =
     "__xla_cpu_runtime_KeyValueSort";
+extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32";
 extern const char* const kTracingStartSymbolName =
     "__xla_cpu_runtime_TracingStart";
 extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index ee75b97e4dc..eb24e0bc334 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -72,6 +72,7 @@ extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
 extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
 extern const char* const kParallelForkJoinSymbolName;
 extern const char* const kKeyValueSortSymbolName;
+extern const char* const kTopKF32SymbolName;
 extern const char* const kAllReduceSymbolName;
 extern const char* const kCollectivePermuteSymbolName;
 extern const char* const kReplicaIdSymbolName;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 278e6479e48..2688a7898af 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -2387,6 +2387,41 @@ Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
   return Status::OK();
 }
 
+Status IrEmitter::HandleTopK(HloInstruction* hlo) {
+  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
+  const HloInstruction* input = hlo->operand(0);
+  int64 k = hlo->shape().tuple_shapes(0).dimensions(1);
+  TF_RET_CHECK(input->shape().element_type() == F32);
+  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
+      hlo->shape().tuple_shapes(0).layout()));
+  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
+      hlo->shape().tuple_shapes(1).layout()));
+  TF_RET_CHECK(
+      LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
+
+  TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
+                      assignment_.GetUniqueSlice(hlo->operand(0), {}));
+  TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice,
+                      assignment_.GetUniqueSlice(hlo, {0}));
+  TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice,
+                      assignment_.GetUniqueSlice(hlo, {1}));
+  llvm::Value* values_ptr =
+      EmitBufferPointer(values_slice, hlo->operand(0)->shape());
+  llvm::Value* out_values_ptr =
+      EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0));
+  llvm::Value* out_indices_ptr =
+      EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1));
+  EmitCallToFunc(runtime::kTopKF32SymbolName,
+                 {b_.getInt64(input->shape().dimensions(0)),
+                  b_.getInt64(input->shape().dimensions(1)), b_.getInt64(k),
+                  values_ptr, out_values_ptr, out_indices_ptr},
+                 b_.getVoidTy());
+
+  llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr},
+                     &b_);
+  return Status::OK();
+}
+
 Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
   if (custom_call->custom_call_target() == "PadToStatic") {
     return HandlePadToStatic(custom_call);
@@ -2394,6 +2429,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
   if (custom_call->custom_call_target() == "SliceToDynamic") {
     return HandleSliceToDynamic(custom_call);
   }
+  if (custom_call->custom_call_target() == "TopK") {
+    return HandleTopK(custom_call);
+  }
   absl::Span<HloInstruction* const> operands(custom_call->operands());
   llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
   llvm::AllocaInst* operands_alloca =
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 3955deefbea..f136e3470e5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -190,6 +190,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
  private:
   Status HandleSliceToDynamic(HloInstruction* hlo);
   Status HandlePadToStatic(HloInstruction* hlo);
+  Status HandleTopK(HloInstruction* hlo);
   Status HandleAllReduceSingleReplica(HloInstruction* crs);
   Status HandleAllReduceMultipleReplica(HloInstruction* crs);
 
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_topk.cc b/tensorflow/compiler/xla/service/cpu/runtime_topk.cc
new file mode 100644
index 00000000000..5174a3329fb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_topk.cc
@@ -0,0 +1,76 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"
+
+#include <algorithm>
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/macros.h"
+
+template <typename T>
+static void TopK(tensorflow::int64 batch_size, tensorflow::int64 input_size,
+                 tensorflow::int64 k, const T* values, T* out_values,
+                 tensorflow::int32* out_indices) {
+  // 'values' is managed by the JIT code, so msan can't tell they are
+  // initialized.
+  TF_ANNOTATE_MEMORY_IS_INITIALIZED(values,
+                                    input_size * batch_size * sizeof(T));
+
+  std::vector<tensorflow::int32> temp_indices(input_size);
+  for (tensorflow::int64 batch = 0; batch != batch_size; ++batch) {
+    std::iota(temp_indices.begin(), temp_indices.end(), 0);
+
+    const T* values_batch = values + batch * input_size;
+
+    auto convert_to_int = [](T value) {
+      tensorflow::uint32 x;
+      std::memcpy(&x, &value, sizeof(x));
+      return static_cast<tensorflow::int32>(x) < 0
+                 ? std::numeric_limits<tensorflow::int32>::max() - x
+                 : x;
+    };
+
+    auto kth_element = temp_indices.begin() + k;
+    std::partial_sort(temp_indices.begin(), kth_element, temp_indices.end(),
+                      [&](size_t i1, size_t i2) {
+                        // Do the comparison in integers to enforce a total
+                        // order of -NaN < -Inf < -0 < +0 < +Inf < +NaN.
+                        tensorflow::int32 v1 = convert_to_int(values_batch[i1]);
+                        tensorflow::int32 v2 = convert_to_int(values_batch[i2]);
+                        if (v1 == v2) {
+                          return i1 < i2;  // Stabilize sorting.
+                        }
+                        return v1 > v2;
+                      });
+
+    T* out_values_batch = out_values + batch * k;
+    tensorflow::int32* out_indices_batch = out_indices + batch * k;
+    std::copy(temp_indices.begin(), kth_element, out_indices_batch);
+    for (tensorflow::int64 i = 0; i < k; i++) {
+      out_values_batch[i] = values_batch[temp_indices[i]];
+    }
+  }
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TopKF32(
+    tensorflow::int64 batch_size, tensorflow::int64 input_size,
+    tensorflow::int64 k, const float* values, float* out_values,
+    tensorflow::int32* out_indices) {
+  TopK(batch_size, input_size, k, values, out_values, out_indices);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_topk.h b/tensorflow/compiler/xla/service/cpu/runtime_topk.h
new file mode 100644
index 00000000000..de69c0603e3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_topk.h
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H
+
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// Calculates `batch_size` topk operations with `input_size` inputs each. The
+// outputs are written to `out_values` and `out_indices`.
+extern void __xla_cpu_runtime_TopKF32(tensorflow::int64 batch_size,
+                                      tensorflow::int64 input_size,
+                                      tensorflow::int64 k, const float* values,
+                                      float* out_values,
+                                      tensorflow::int32* out_indices);
+}
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 631c6985b03..28508bde4cd 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h"
 #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -270,6 +271,7 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
   REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort);
+  REGISTER_CPU_RUNTIME_SYMBOL(TopKF32);
   REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
   REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
 
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index d7c50dce3ca..527071d5f31 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -253,6 +253,22 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "cpu_topk_test",
+    srcs = ["cpu_topk_test.cc"],
+    deps = [
+        ":cpu_codegen_test",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client/lib:sorting",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+        "//tensorflow/compiler/xla/service/cpu:test_header_helper",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 tf_cc_test(
     name = "cpu_vectorization_test",
     srcs = ["cpu_vectorization_test.cc"],
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc
new file mode 100644
index 00000000000..a4c74cfb8a2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc
@@ -0,0 +1,59 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+
+using CpuTopKTest = CpuCodegenTest;
+
+TEST_F(CpuTopKTest, CallRuntime) {
+  XlaBuilder builder(TestName());
+  XlaOp input =
+      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 100}), "input");
+  TopK(input, 10);
+  TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, builder.Build());
+
+  TF_ASSERT_OK_AND_ASSIGN(ProgramShape program_shape,
+                          xla_computation.GetProgramShape());
+  HloModuleConfig config(program_shape);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, HloModule::CreateFromProto(xla_computation.proto(), config));
+
+  constexpr char filecheck_pattern[] = R"(
+    CHECK: call void @__xla_cpu_runtime_TopKF32(i64 5, i64 100, i64 10,
+  )";
+
+  CpuAotCompilationOptions options{
+      /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost,
+      /*features=*/"",
+      /*entry_point_name=*/"entry",
+      /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+  CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+                                /*match_optimized_ir=*/true);
+}
+
+}  // namespace
+}  // namespace cpu
+}  // namespace xla