diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
index 44bcab497ee..6fab5f1f5ad 100644
--- a/tensorflow/core/kernels/sparse_xent_op.cc
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/kernels/sparse_xent_op.h"
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -123,8 +122,6 @@ REGISTER(CPU, float, int32)
 REGISTER(CPU, float, int64)
 REGISTER(CPU, double, int32)
 REGISTER(CPU, double, int64)
-REGISTER(CPU, bfloat16, int32)
-REGISTER(CPU, bfloat16, int64)
 REGISTER(CPU, Eigen::half, int32)
 REGISTER(CPU, Eigen::half, int64)
 
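The two deleted REGISTER lines are the bfloat16 CPU kernels, one per label index type; the float, double, and Eigen::half registrations remain. The REGISTER macro is defined earlier in this file and is not visible in the hunk; assuming it follows the standard REGISTER_KERNEL_BUILDER pattern, each deleted line corresponds roughly to a registration like the sketch below (the op class name and the CPUDevice alias are assumptions, not taken from this patch):

  // Rough sketch of what REGISTER(CPU, bfloat16, int32) presumably expands to;
  // SparseSoftmaxXentWithLogitsOp and CPUDevice are assumed names, not shown in this hunk.
  REGISTER_KERNEL_BUILDER(Name("SparseSoftmaxCrossEntropyWithLogits")
                              .Device(DEVICE_CPU)
                              .TypeConstraint<bfloat16>("T")
                              .TypeConstraint<int32>("Tlabels"),
                          SparseSoftmaxXentWithLogitsOp<CPUDevice, bfloat16, int32>);
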
diff --git a/tensorflow/core/kernels/sparse_xent_op_test.cc b/tensorflow/core/kernels/sparse_xent_op_test.cc
index f095f2e2cf7..85a5cd3befc 100644
--- a/tensorflow/core/kernels/sparse_xent_op_test.cc
+++ b/tensorflow/core/kernels/sparse_xent_op_test.cc
@@ -23,9 +23,9 @@ limitations under the License.
 
 namespace tensorflow {
 
-static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) {
+static Graph* SparseXent(int batch_size, int num_classes) {
   Graph* g = new Graph(OpRegistry::Global());
-  Tensor logits(value_type, TensorShape({batch_size, num_classes}));
+  Tensor logits(DT_FLOAT, TensorShape({batch_size, num_classes}));
   logits.flat<float>().setRandom();
   Tensor labels(DT_INT64, TensorShape({batch_size}));
   std::random_device rd;
@@ -41,45 +41,44 @@ static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) {
   return g;
 }
 
-#define BM_SparseXentDev(BATCH, CLASS, DEVICE, DTYPE)                        \
-  static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE(        \
+#define BM_SparseXentDev(BATCH, CLASS, DEVICE)                               \
+  static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE(                  \
       ::testing::benchmark::State& state) {                                  \
-    test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS, DTYPE),                \
+    test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS),                       \
                     /*old_benchmark_api*/ false)                             \
         .Run(state);                                                         \
     state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \
                             CLASS);                                          \
   }                                                                          \
-  BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE);
-
-#define BM_SPARSE_XENT_DEV_CPU(DTYPE)       \
-  BM_SparseXentDev(8, 1000000, cpu, DTYPE); \
-  BM_SparseXentDev(16, 10000, cpu, DTYPE);  \
-  BM_SparseXentDev(16, 100000, cpu, DTYPE); \
-  BM_SparseXentDev(32, 10000, cpu, DTYPE);  \
-  BM_SparseXentDev(32, 100000, cpu, DTYPE); \
-  BM_SparseXentDev(64, 10000, cpu, DTYPE);  \
-  BM_SparseXentDev(64, 100000, cpu, DTYPE);
-
-// CPU
-BM_SPARSE_XENT_DEV_CPU(DT_FLOAT);
-BM_SPARSE_XENT_DEV_CPU(DT_BFLOAT16);
+  BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE);
 
 /// The representative tests for ptb_word on GPU
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-BM_SparseXentDev(8, 1000000, gpu, DT_FLOAT);
+BM_SparseXentDev(8, 1000000, gpu);
 
-BM_SparseXentDev(16, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(16, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(16, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(16, 10000, gpu);
+BM_SparseXentDev(16, 30000, gpu);
+BM_SparseXentDev(16, 100000, gpu);
 
-BM_SparseXentDev(32, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(32, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(32, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(32, 10000, gpu);
+BM_SparseXentDev(32, 30000, gpu);
+BM_SparseXentDev(32, 100000, gpu);
 
-BM_SparseXentDev(64, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(64, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(64, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(64, 10000, gpu);
+BM_SparseXentDev(64, 30000, gpu);
+BM_SparseXentDev(64, 100000, gpu);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
+// CPU
+BM_SparseXentDev(8, 1000000, cpu);
+
+BM_SparseXentDev(16, 10000, cpu);
+BM_SparseXentDev(16, 100000, cpu);
+
+BM_SparseXentDev(32, 10000, cpu);
+BM_SparseXentDev(32, 100000, cpu);
+
+BM_SparseXentDev(64, 10000, cpu);
+BM_SparseXentDev(64, 100000, cpu);
+
 }  // end namespace tensorflow
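
For reference, the benchmark macro above builds its function names purely by token pasting; a sketch of what one instantiation, BM_SparseXentDev(16, 10000, cpu), expands to, assuming the usual tensorflow test and benchmark headers are in scope:

  // Approximate expansion of BM_SparseXentDev(16, 10000, cpu) as defined above.
  static void BM_SparseXent_16_10000_cpu(::testing::benchmark::State& state) {
    test::Benchmark("cpu", SparseXent(16, 10000),
                    /*old_benchmark_api*/ false)
        .Run(state);
    state.SetItemsProcessed(static_cast<int64>(state.iterations()) * 16 * 10000);
  }
  BENCHMARK(BM_SparseXent_16_10000_cpu);

Since SparseXent hard-codes DT_FLOAT for the logits tensor, every instantiation above, CPU and GPU alike, benchmarks float inputs.
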
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index 99f70c16999..c53f196ecb9 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -182,23 +182,6 @@ class SparseXentTest(test.TestCase):
           np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
           np.array([0, 3]).astype(label_dtype))
 
-  def testBfloat16(self):
-    for label_dtype in np.int32, np.int64:
-      np_features = np.array([[1., 1., 1., 1.], [1., 2., 3.,
-                                                 4.]]).astype(np.float32)
-      np_labels = np.array([0, 3]).astype(label_dtype)
-      np_loss, np_backprop = self._npXent(np_features, np_labels)
-
-      np_features_bf16 = math_ops.cast(np_features, dtypes.bfloat16)
-      np_loss_bf16 = math_ops.cast(np_loss, dtypes.bfloat16)
-      np_backprop_bf16 = math_ops.cast(np_backprop, dtypes.bfloat16)
-      with self.cached_session(use_gpu=False):
-        loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
-            np_features_bf16, np_labels)
-        tf_loss, tf_backprop = self.evaluate([loss, backprop])
-      self.assertAllCloseAccordingToType(np_loss_bf16, tf_loss)
-      self.assertAllCloseAccordingToType(np_backprop_bf16, tf_backprop)
-
   def testHalf(self):
     for label_dtype in np.int32, np.int64:
       self._testXent(