tf.distribute: Add MultiWorkerMirroredStrategy to keras_image_correctness_test.

While I am here, enable some of the partial batch tests with TPUStrategy + eager, those were previously only tested for graph mode.

PiperOrigin-RevId: 332513819
Change-Id: I32a43191550b1ec978d86ad470fdfdd63d81083f
This commit is contained in:
Priya Gupta 2020-09-18 13:18:47 -07:00 committed by TensorFlower Gardener
parent 00302787b7
commit 19bb9fcff5
4 changed files with 28 additions and 16 deletions

View File

@ -74,11 +74,16 @@ def graph_mode_test_configuration():
def all_strategy_and_input_config_combinations(): def all_strategy_and_input_config_combinations():
return (combinations.times( return (combinations.times(
combinations.combine( combinations.combine(distribution=all_strategies),
distribution=all_strategies),
eager_mode_test_configuration() + graph_mode_test_configuration())) eager_mode_test_configuration() + graph_mode_test_configuration()))
def all_strategy_and_input_config_combinations_eager():
return (combinations.times(
combinations.combine(distribution=all_strategies),
eager_mode_test_configuration()))
def strategy_minus_tpu_and_input_config_combinations_eager(): def strategy_minus_tpu_and_input_config_combinations_eager():
return (combinations.times( return (combinations.times(
combinations.combine( combinations.combine(
@ -114,7 +119,7 @@ def test_combinations_for_embedding_model():
(eager_mode_test_configuration()))) (eager_mode_test_configuration())))
def test_combinations_with_tpu_strategies(): def test_combinations_with_tpu_strategies_graph():
tpu_strategies = [ tpu_strategies = [
strategy_combinations.tpu_strategy, strategy_combinations.tpu_strategy,
] ]

View File

@ -110,7 +110,8 @@ class TestDistributionStrategyDnnCorrectness(
self.run_correctness_test(distribution, use_numpy, use_validation_data) self.run_correctness_test(distribution, use_numpy, use_validation_data)
@ds_combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies() + keras_correctness_test_base
.test_combinations_with_tpu_strategies_graph() +
keras_correctness_test_base.multi_worker_mirrored_eager()) keras_correctness_test_base.multi_worker_mirrored_eager())
def test_dnn_correctness_with_partial_last_batch_eval(self, distribution, def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
use_numpy, use_numpy,
@ -309,7 +310,7 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
self.run_dynamic_lr_test(distribution) self.run_dynamic_lr_test(distribution)
@ds_combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies()) keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
def test_dnn_correctness_with_partial_last_batch_eval(self, distribution, def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
use_numpy, use_numpy,
use_validation_data): use_validation_data):

View File

@ -20,11 +20,11 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.distribute import combinations as ds_combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.platform import test
@testing_utils.run_all_without_tensor_float_32( @testing_utils.run_all_without_tensor_float_32(
@ -97,12 +97,14 @@ class DistributionStrategyCnnCorrectnessTest(
return x_train, y_train, x_eval, y_eval, x_eval return x_train, y_train, x_eval, y_eval, x_eval
@ds_combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations()) keras_correctness_test_base.all_strategy_and_input_config_combinations() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_cnn_correctness(self, distribution, use_numpy, use_validation_data): def test_cnn_correctness(self, distribution, use_numpy, use_validation_data):
self.run_correctness_test(distribution, use_numpy, use_validation_data) self.run_correctness_test(distribution, use_numpy, use_validation_data)
@ds_combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations()) keras_correctness_test_base.all_strategy_and_input_config_combinations() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy, def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy,
use_validation_data): use_validation_data):
self.run_correctness_test( self.run_correctness_test(
@ -112,7 +114,8 @@ class DistributionStrategyCnnCorrectnessTest(
with_batch_norm='regular') with_batch_norm='regular')
@ds_combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations()) keras_correctness_test_base.all_strategy_and_input_config_combinations() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy, def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy,
use_validation_data): use_validation_data):
if not context.executing_eagerly(): if not context.executing_eagerly():
@ -125,9 +128,10 @@ class DistributionStrategyCnnCorrectnessTest(
with_batch_norm='sync') with_batch_norm='sync')
@ds_combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies() +
keras_correctness_test_base keras_correctness_test_base
.strategy_minus_tpu_and_input_config_combinations_eager()) .all_strategy_and_input_config_combinations_eager() +
keras_correctness_test_base.multi_worker_mirrored_eager() +
keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
def test_cnn_correctness_with_partial_last_batch_eval(self, distribution, def test_cnn_correctness_with_partial_last_batch_eval(self, distribution,
use_numpy, use_numpy,
use_validation_data): use_validation_data):
@ -139,9 +143,10 @@ class DistributionStrategyCnnCorrectnessTest(
training_epochs=1) training_epochs=1)
@ds_combinations.generate( @ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies() + keras_correctness_test_base.
keras_correctness_test_base all_strategy_and_input_config_combinations_eager() +
.strategy_minus_tpu_and_input_config_combinations_eager()) keras_correctness_test_base.multi_worker_mirrored_eager() +
keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval( def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
self, distribution, use_numpy, use_validation_data): self, distribution, use_numpy, use_validation_data):
self.run_correctness_test( self.run_correctness_test(
@ -153,4 +158,4 @@ class DistributionStrategyCnnCorrectnessTest(
if __name__ == '__main__': if __name__ == '__main__':
test.main() multi_process_runner.test_main()

View File

@ -93,7 +93,8 @@ class DistributionStrategyStatefulLstmModelCorrectnessTest(
@ds_combinations.generate( @ds_combinations.generate(
combinations.times( combinations.times(
keras_correctness_test_base.test_combinations_with_tpu_strategies())) keras_correctness_test_base
.test_combinations_with_tpu_strategies_graph()))
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model( def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
self, distribution, use_numpy, use_validation_data): self, distribution, use_numpy, use_validation_data):
with self.assertRaisesRegex( with self.assertRaisesRegex(