Add MWMS combinations to keras_embedding_model_correctness_test.
We can add the combinations to test_combinations_for_embedding_model(), once other tests that use the said combination also enable MWMS combinations. PiperOrigin-RevId: 332902076 Change-Id: I1d9df59862497894c691f07a02ab1b1fbdea0a05
This commit is contained in:
parent
f50646565f
commit
c1cc87436a
@ -18,11 +18,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.distribute import combinations as ds_combinations
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.keras.distribute import keras_correctness_test_base
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class DistributionStrategyEmbeddingModelCorrectnessTest(
|
||||
@ -56,7 +57,8 @@ class DistributionStrategyEmbeddingModelCorrectnessTest(
|
||||
return model
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.test_combinations_for_embedding_model())
|
||||
keras_correctness_test_base.test_combinations_for_embedding_model() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_embedding_model_correctness(self, distribution, use_numpy,
|
||||
use_validation_data):
|
||||
|
||||
@ -64,7 +66,8 @@ class DistributionStrategyEmbeddingModelCorrectnessTest(
|
||||
self.run_correctness_test(distribution, use_numpy, use_validation_data)
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.test_combinations_for_embedding_model())
|
||||
keras_correctness_test_base.test_combinations_for_embedding_model() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_embedding_time_distributed_model_correctness(
|
||||
self, distribution, use_numpy, use_validation_data):
|
||||
self.use_distributed_dense = True
|
||||
@ -146,11 +149,12 @@ class DistributionStrategySiameseEmbeddingModelCorrectnessTest(
|
||||
return x_train, y_train, x_predict
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.test_combinations_for_embedding_model())
|
||||
keras_correctness_test_base.test_combinations_for_embedding_model() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_siamese_embedding_model_correctness(self, distribution, use_numpy,
|
||||
use_validation_data):
|
||||
self.run_correctness_test(distribution, use_numpy, use_validation_data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
multi_process_runner.test_main()
|
||||
|
Loading…
Reference in New Issue
Block a user