From c1cc87436ab8c98f3f40ad8f96667c76683ef3f1 Mon Sep 17 00:00:00 2001 From: Chenkai Kuang Date: Mon, 21 Sep 2020 12:04:32 -0700 Subject: [PATCH] Add MWMS combinations to keras_embedding_model_correctness_test. We can add the combinations to test_combinations_for_embedding_model(), once other tests that use the said combination also enable MWMS combinations. PiperOrigin-RevId: 332902076 Change-Id: I1d9df59862497894c691f07a02ab1b1fbdea0a05 --- .../keras_embedding_model_correctness_test.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py index 91cb1cc77fd..9d67fe660cc 100644 --- a/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_embedding_model_correctness_test.py @@ -18,11 +18,12 @@ from __future__ import division from __future__ import print_function import numpy as np + from tensorflow.python import keras from tensorflow.python.distribute import combinations as ds_combinations +from tensorflow.python.distribute import multi_process_runner from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras -from tensorflow.python.platform import test class DistributionStrategyEmbeddingModelCorrectnessTest( @@ -56,7 +57,8 @@ class DistributionStrategyEmbeddingModelCorrectnessTest( return model @ds_combinations.generate( - keras_correctness_test_base.test_combinations_for_embedding_model()) + keras_correctness_test_base.test_combinations_for_embedding_model() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_embedding_model_correctness(self, distribution, use_numpy, use_validation_data): @@ -64,7 +66,8 @@ class DistributionStrategyEmbeddingModelCorrectnessTest( self.run_correctness_test(distribution, use_numpy, use_validation_data) @ds_combinations.generate( - keras_correctness_test_base.test_combinations_for_embedding_model()) + keras_correctness_test_base.test_combinations_for_embedding_model() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_embedding_time_distributed_model_correctness( self, distribution, use_numpy, use_validation_data): self.use_distributed_dense = True @@ -146,11 +149,12 @@ class DistributionStrategySiameseEmbeddingModelCorrectnessTest( return x_train, y_train, x_predict @ds_combinations.generate( - keras_correctness_test_base.test_combinations_for_embedding_model()) + keras_correctness_test_base.test_combinations_for_embedding_model() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_siamese_embedding_model_correctness(self, distribution, use_numpy, use_validation_data): self.run_correctness_test(distribution, use_numpy, use_validation_data) if __name__ == '__main__': - test.main() + multi_process_runner.test_main()