From 19bb9fcff5663115b318fefc972b59281c3bad62 Mon Sep 17 00:00:00 2001
From: Priya Gupta <priyag@google.com>
Date: Fri, 18 Sep 2020 13:18:47 -0700
Subject: [PATCH] tf.distribute: Add MultiWorkerMirroredStrategy to
 keras_image_correctness_test.

While I am here, enable some of the partial batch tests with TPUStrategy + eager, those were previously only tested for graph mode.

PiperOrigin-RevId: 332513819
Change-Id: I32a43191550b1ec978d86ad470fdfdd63d81083f
---
 .../distribute/keras_correctness_test_base.py | 11 +++++---
 .../distribute/keras_dnn_correctness_test.py  |  5 ++--
 .../keras_image_model_correctness_test.py     | 25 +++++++++++--------
 ...as_stateful_lstm_model_correctness_test.py |  3 ++-
 4 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
index 8f019cdb7c9..2639b6a78b8 100644
--- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py
+++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
@@ -74,11 +74,16 @@ def graph_mode_test_configuration():
 
 def all_strategy_and_input_config_combinations():
   return (combinations.times(
-      combinations.combine(
-          distribution=all_strategies),
+      combinations.combine(distribution=all_strategies),
       eager_mode_test_configuration() + graph_mode_test_configuration()))
 
 
+def all_strategy_and_input_config_combinations_eager():
+  return (combinations.times(
+      combinations.combine(distribution=all_strategies),
+      eager_mode_test_configuration()))
+
+
 def strategy_minus_tpu_and_input_config_combinations_eager():
   return (combinations.times(
       combinations.combine(
@@ -114,7 +119,7 @@ def test_combinations_for_embedding_model():
           (eager_mode_test_configuration())))
 
 
-def test_combinations_with_tpu_strategies():
+def test_combinations_with_tpu_strategies_graph():
   tpu_strategies = [
       strategy_combinations.tpu_strategy,
   ]
diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
index 621b7feadf7..e6581a82692 100644
--- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
@@ -110,7 +110,8 @@ class TestDistributionStrategyDnnCorrectness(
     self.run_correctness_test(distribution, use_numpy, use_validation_data)
 
   @ds_combinations.generate(
-      keras_correctness_test_base.test_combinations_with_tpu_strategies() +
+      keras_correctness_test_base
+      .test_combinations_with_tpu_strategies_graph() +
       keras_correctness_test_base.multi_worker_mirrored_eager())
   def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
                                                         use_numpy,
@@ -309,7 +310,7 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
         self.run_dynamic_lr_test(distribution)
 
   @ds_combinations.generate(
-      keras_correctness_test_base.test_combinations_with_tpu_strategies())
+      keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
   def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
                                                         use_numpy,
                                                         use_validation_data):
diff --git a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
index ed1e707d04a..e47b3ba519c 100644
--- a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
@@ -20,11 +20,11 @@ from __future__ import print_function
 import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.distribute import combinations as ds_combinations
+from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.eager import context
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import keras_correctness_test_base
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
-from tensorflow.python.platform import test
 
 
 @testing_utils.run_all_without_tensor_float_32(
@@ -97,12 +97,14 @@ class DistributionStrategyCnnCorrectnessTest(
     return x_train, y_train, x_eval, y_eval, x_eval
 
   @ds_combinations.generate(
-      keras_correctness_test_base.all_strategy_and_input_config_combinations())
+      keras_correctness_test_base.all_strategy_and_input_config_combinations() +
+      keras_correctness_test_base.multi_worker_mirrored_eager())
   def test_cnn_correctness(self, distribution, use_numpy, use_validation_data):
     self.run_correctness_test(distribution, use_numpy, use_validation_data)
 
   @ds_combinations.generate(
-      keras_correctness_test_base.all_strategy_and_input_config_combinations())
+      keras_correctness_test_base.all_strategy_and_input_config_combinations() +
+      keras_correctness_test_base.multi_worker_mirrored_eager())
   def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy,
                                            use_validation_data):
     self.run_correctness_test(
@@ -112,7 +114,8 @@ class DistributionStrategyCnnCorrectnessTest(
         with_batch_norm='regular')
 
   @ds_combinations.generate(
-      keras_correctness_test_base.all_strategy_and_input_config_combinations())
+      keras_correctness_test_base.all_strategy_and_input_config_combinations() +
+      keras_correctness_test_base.multi_worker_mirrored_eager())
   def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy,
                                                 use_validation_data):
     if not context.executing_eagerly():
@@ -125,9 +128,10 @@ class DistributionStrategyCnnCorrectnessTest(
         with_batch_norm='sync')
 
   @ds_combinations.generate(
-      keras_correctness_test_base.test_combinations_with_tpu_strategies() +
       keras_correctness_test_base
-      .strategy_minus_tpu_and_input_config_combinations_eager())
+      .all_strategy_and_input_config_combinations_eager() +
+      keras_correctness_test_base.multi_worker_mirrored_eager() +
+      keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
   def test_cnn_correctness_with_partial_last_batch_eval(self, distribution,
                                                         use_numpy,
                                                         use_validation_data):
@@ -139,9 +143,10 @@ class DistributionStrategyCnnCorrectnessTest(
         training_epochs=1)
 
   @ds_combinations.generate(
-      keras_correctness_test_base.test_combinations_with_tpu_strategies() +
-      keras_correctness_test_base
-      .strategy_minus_tpu_and_input_config_combinations_eager())
+      keras_correctness_test_base.
+      all_strategy_and_input_config_combinations_eager() +
+      keras_correctness_test_base.multi_worker_mirrored_eager() +
+      keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
   def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
       self, distribution, use_numpy, use_validation_data):
     self.run_correctness_test(
@@ -153,4 +158,4 @@ class DistributionStrategyCnnCorrectnessTest(
 
 
 if __name__ == '__main__':
-  test.main()
+  multi_process_runner.test_main()
diff --git a/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py
index 199a1f390a4..2a1ec826b0d 100644
--- a/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py
@@ -93,7 +93,8 @@ class DistributionStrategyStatefulLstmModelCorrectnessTest(
 
   @ds_combinations.generate(
       combinations.times(
-          keras_correctness_test_base.test_combinations_with_tpu_strategies()))
+          keras_correctness_test_base
+          .test_combinations_with_tpu_strategies_graph()))
   def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
       self, distribution, use_numpy, use_validation_data):
     with self.assertRaisesRegex(