diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index e116ba9082f..e1db701bcd5 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -417,7 +417,7 @@ distribute_py_test(
     srcs = ["keras_embedding_model_correctness_test.py"],
     full_precision = True,
     main = "keras_embedding_model_correctness_test.py",
-    shard_count = 4,
+    shard_count = 8,
     tags = [
         "multi_and_single_gpu",
         "no_windows_gpu",