Add MWMS combinations to keras_embedding_model_correctness_test.
We can add the combinations to test_combinations_for_embedding_model(), once other tests that use the said combination also enable MWMS combinations. PiperOrigin-RevId: 332902076 Change-Id: I1d9df59862497894c691f07a02ab1b1fbdea0a05
This commit is contained in:
parent
f50646565f
commit
c1cc87436a
@ -18,11 +18,12 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.distribute import combinations as ds_combinations
|
from tensorflow.python.distribute import combinations as ds_combinations
|
||||||
|
from tensorflow.python.distribute import multi_process_runner
|
||||||
from tensorflow.python.keras.distribute import keras_correctness_test_base
|
from tensorflow.python.keras.distribute import keras_correctness_test_base
|
||||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
|
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
|
||||||
from tensorflow.python.platform import test
|
|
||||||
|
|
||||||
|
|
||||||
class DistributionStrategyEmbeddingModelCorrectnessTest(
|
class DistributionStrategyEmbeddingModelCorrectnessTest(
|
||||||
@ -56,7 +57,8 @@ class DistributionStrategyEmbeddingModelCorrectnessTest(
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@ds_combinations.generate(
|
@ds_combinations.generate(
|
||||||
keras_correctness_test_base.test_combinations_for_embedding_model())
|
keras_correctness_test_base.test_combinations_for_embedding_model() +
|
||||||
|
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||||
def test_embedding_model_correctness(self, distribution, use_numpy,
|
def test_embedding_model_correctness(self, distribution, use_numpy,
|
||||||
use_validation_data):
|
use_validation_data):
|
||||||
|
|
||||||
@ -64,7 +66,8 @@ class DistributionStrategyEmbeddingModelCorrectnessTest(
|
|||||||
self.run_correctness_test(distribution, use_numpy, use_validation_data)
|
self.run_correctness_test(distribution, use_numpy, use_validation_data)
|
||||||
|
|
||||||
@ds_combinations.generate(
|
@ds_combinations.generate(
|
||||||
keras_correctness_test_base.test_combinations_for_embedding_model())
|
keras_correctness_test_base.test_combinations_for_embedding_model() +
|
||||||
|
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||||
def test_embedding_time_distributed_model_correctness(
|
def test_embedding_time_distributed_model_correctness(
|
||||||
self, distribution, use_numpy, use_validation_data):
|
self, distribution, use_numpy, use_validation_data):
|
||||||
self.use_distributed_dense = True
|
self.use_distributed_dense = True
|
||||||
@ -146,11 +149,12 @@ class DistributionStrategySiameseEmbeddingModelCorrectnessTest(
|
|||||||
return x_train, y_train, x_predict
|
return x_train, y_train, x_predict
|
||||||
|
|
||||||
@ds_combinations.generate(
|
@ds_combinations.generate(
|
||||||
keras_correctness_test_base.test_combinations_for_embedding_model())
|
keras_correctness_test_base.test_combinations_for_embedding_model() +
|
||||||
|
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||||
def test_siamese_embedding_model_correctness(self, distribution, use_numpy,
|
def test_siamese_embedding_model_correctness(self, distribution, use_numpy,
|
||||||
use_validation_data):
|
use_validation_data):
|
||||||
self.run_correctness_test(distribution, use_numpy, use_validation_data)
|
self.run_correctness_test(distribution, use_numpy, use_validation_data)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
multi_process_runner.test_main()
|
||||||
|
Loading…
Reference in New Issue
Block a user