Add MWMS combinations to keras_embedding_model_correctness_test.

We can add the combinations to test_combinations_for_embedding_model(), once other tests that use the said combination also enable MWMS combinations.

PiperOrigin-RevId: 332902076
Change-Id: I1d9df59862497894c691f07a02ab1b1fbdea0a05
This commit is contained in:
Chenkai Kuang 2020-09-21 12:04:32 -07:00 committed by TensorFlower Gardener
parent f50646565f
commit c1cc87436a

View File

@ -18,11 +18,12 @@ from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
from tensorflow.python.platform import test
class DistributionStrategyEmbeddingModelCorrectnessTest(
@ -56,7 +57,8 @@ class DistributionStrategyEmbeddingModelCorrectnessTest(
return model
@ds_combinations.generate(
keras_correctness_test_base.test_combinations_for_embedding_model())
keras_correctness_test_base.test_combinations_for_embedding_model() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_embedding_model_correctness(self, distribution, use_numpy,
use_validation_data):
@ -64,7 +66,8 @@ class DistributionStrategyEmbeddingModelCorrectnessTest(
self.run_correctness_test(distribution, use_numpy, use_validation_data)
@ds_combinations.generate(
keras_correctness_test_base.test_combinations_for_embedding_model())
keras_correctness_test_base.test_combinations_for_embedding_model() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_embedding_time_distributed_model_correctness(
self, distribution, use_numpy, use_validation_data):
self.use_distributed_dense = True
@ -146,11 +149,12 @@ class DistributionStrategySiameseEmbeddingModelCorrectnessTest(
return x_train, y_train, x_predict
@ds_combinations.generate(
keras_correctness_test_base.test_combinations_for_embedding_model())
keras_correctness_test_base.test_combinations_for_embedding_model() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_siamese_embedding_model_correctness(self, distribution, use_numpy,
use_validation_data):
self.run_correctness_test(distribution, use_numpy, use_validation_data)
if __name__ == '__main__':
test.main()
multi_process_runner.test_main()