diff --git a/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py index 91cb1cc77fd..9d67fe660cc 100644 --- a/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py @@ -18,11 +18,12 @@ from __future__ import division from __future__ import print_function import numpy as np + from tensorflow.python import keras from tensorflow.python.distribute import combinations as ds_combinations +from tensorflow.python.distribute import multi_process_runner from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras -from tensorflow.python.platform import test class DistributionStrategyEmbeddingModelCorrectnessTest( @@ -56,7 +57,8 @@ class DistributionStrategyEmbeddingModelCorrectnessTest( return model @ds_combinations.generate( - keras_correctness_test_base.test_combinations_for_embedding_model()) + keras_correctness_test_base.test_combinations_for_embedding_model() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_embedding_model_correctness(self, distribution, use_numpy, use_validation_data): @@ -64,7 +66,8 @@ class DistributionStrategyEmbeddingModelCorrectnessTest( self.run_correctness_test(distribution, use_numpy, use_validation_data) @ds_combinations.generate( - keras_correctness_test_base.test_combinations_for_embedding_model()) + keras_correctness_test_base.test_combinations_for_embedding_model() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_embedding_time_distributed_model_correctness( self, distribution, use_numpy, use_validation_data): self.use_distributed_dense = True @@ -146,11 +149,12 @@ class DistributionStrategySiameseEmbeddingModelCorrectnessTest( return x_train, y_train, x_predict @ds_combinations.generate( - keras_correctness_test_base.test_combinations_for_embedding_model()) + keras_correctness_test_base.test_combinations_for_embedding_model() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_siamese_embedding_model_correctness(self, distribution, use_numpy, use_validation_data): self.run_correctness_test(distribution, use_numpy, use_validation_data) if __name__ == '__main__': - test.main() + multi_process_runner.test_main()