diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py index 8f019cdb7c9..2639b6a78b8 100644 --- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py +++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py @@ -74,11 +74,16 @@ def graph_mode_test_configuration(): def all_strategy_and_input_config_combinations(): return (combinations.times( - combinations.combine( - distribution=all_strategies), + combinations.combine(distribution=all_strategies), eager_mode_test_configuration() + graph_mode_test_configuration())) +def all_strategy_and_input_config_combinations_eager(): + return (combinations.times( + combinations.combine(distribution=all_strategies), + eager_mode_test_configuration())) + + def strategy_minus_tpu_and_input_config_combinations_eager(): return (combinations.times( combinations.combine( @@ -114,7 +119,7 @@ def test_combinations_for_embedding_model(): (eager_mode_test_configuration()))) -def test_combinations_with_tpu_strategies(): +def test_combinations_with_tpu_strategies_graph(): tpu_strategies = [ strategy_combinations.tpu_strategy, ] diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py index 621b7feadf7..e6581a82692 100644 --- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py @@ -110,7 +110,8 @@ class TestDistributionStrategyDnnCorrectness( self.run_correctness_test(distribution, use_numpy, use_validation_data) @ds_combinations.generate( - keras_correctness_test_base.test_combinations_with_tpu_strategies() + + keras_correctness_test_base + .test_combinations_with_tpu_strategies_graph() + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_dnn_correctness_with_partial_last_batch_eval(self, distribution, use_numpy, @@ -309,7 +310,7 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel( self.run_dynamic_lr_test(distribution) @ds_combinations.generate( - keras_correctness_test_base.test_combinations_with_tpu_strategies()) + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()) def test_dnn_correctness_with_partial_last_batch_eval(self, distribution, use_numpy, use_validation_data): diff --git a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py index ed1e707d04a..e47b3ba519c 100644 --- a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py @@ -20,11 +20,11 @@ from __future__ import print_function import numpy as np from tensorflow.python import keras from tensorflow.python.distribute import combinations as ds_combinations +from tensorflow.python.distribute import multi_process_runner from tensorflow.python.eager import context from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.optimizer_v2 import gradient_descent -from tensorflow.python.platform import test @testing_utils.run_all_without_tensor_float_32( @@ -97,12 +97,14 @@ class DistributionStrategyCnnCorrectnessTest( return x_train, y_train, x_eval, y_eval, x_eval @ds_combinations.generate( - keras_correctness_test_base.all_strategy_and_input_config_combinations()) + keras_correctness_test_base.all_strategy_and_input_config_combinations() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_cnn_correctness(self, distribution, use_numpy, use_validation_data): self.run_correctness_test(distribution, use_numpy, use_validation_data) @ds_combinations.generate( - keras_correctness_test_base.all_strategy_and_input_config_combinations()) + keras_correctness_test_base.all_strategy_and_input_config_combinations() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy, use_validation_data): self.run_correctness_test( @@ -112,7 +114,8 @@ class DistributionStrategyCnnCorrectnessTest( with_batch_norm='regular') @ds_combinations.generate( - keras_correctness_test_base.all_strategy_and_input_config_combinations()) + keras_correctness_test_base.all_strategy_and_input_config_combinations() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy, use_validation_data): if not context.executing_eagerly(): @@ -125,9 +128,10 @@ class DistributionStrategyCnnCorrectnessTest( with_batch_norm='sync') @ds_combinations.generate( - keras_correctness_test_base.test_combinations_with_tpu_strategies() + keras_correctness_test_base - .strategy_minus_tpu_and_input_config_combinations_eager()) + .all_strategy_and_input_config_combinations_eager() + + keras_correctness_test_base.multi_worker_mirrored_eager() + + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()) def test_cnn_correctness_with_partial_last_batch_eval(self, distribution, use_numpy, use_validation_data): @@ -139,9 +143,10 @@ class DistributionStrategyCnnCorrectnessTest( training_epochs=1) @ds_combinations.generate( - keras_correctness_test_base.test_combinations_with_tpu_strategies() + - keras_correctness_test_base - .strategy_minus_tpu_and_input_config_combinations_eager()) + keras_correctness_test_base. + all_strategy_and_input_config_combinations_eager() + + keras_correctness_test_base.multi_worker_mirrored_eager() + + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()) def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval( self, distribution, use_numpy, use_validation_data): self.run_correctness_test( @@ -153,4 +158,4 @@ class DistributionStrategyCnnCorrectnessTest( if __name__ == '__main__': - test.main() + multi_process_runner.test_main() diff --git a/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py index 199a1f390a4..2a1ec826b0d 100644 --- a/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_stateful_lstm_model_correctness_test.py @@ -93,7 +93,8 @@ class DistributionStrategyStatefulLstmModelCorrectnessTest( @ds_combinations.generate( combinations.times( - keras_correctness_test_base.test_combinations_with_tpu_strategies())) + keras_correctness_test_base + .test_combinations_with_tpu_strategies_graph())) def test_incorrectly_use_multiple_cores_for_stateful_lstm_model( self, distribution, use_numpy, use_validation_data): with self.assertRaisesRegex(