tf.distribute: Add MultiWorkerMirroredStrategy to keras_image_correctness_test.

While I am here, enable some of the partial batch tests with TPUStrategy + eager, those were previously only tested for graph mode.

PiperOrigin-RevId: 332513819
Change-Id: I32a43191550b1ec978d86ad470fdfdd63d81083f
This commit is contained in:
Priya Gupta 2020-09-18 13:18:47 -07:00 committed by TensorFlower Gardener
parent 00302787b7
commit 19bb9fcff5
4 changed files with 28 additions and 16 deletions

View File

@ -74,11 +74,16 @@ def graph_mode_test_configuration():
def all_strategy_and_input_config_combinations():
return (combinations.times(
combinations.combine(
distribution=all_strategies),
combinations.combine(distribution=all_strategies),
eager_mode_test_configuration() + graph_mode_test_configuration()))
def all_strategy_and_input_config_combinations_eager():
return (combinations.times(
combinations.combine(distribution=all_strategies),
eager_mode_test_configuration()))
def strategy_minus_tpu_and_input_config_combinations_eager():
return (combinations.times(
combinations.combine(
@ -114,7 +119,7 @@ def test_combinations_for_embedding_model():
(eager_mode_test_configuration())))
def test_combinations_with_tpu_strategies():
def test_combinations_with_tpu_strategies_graph():
tpu_strategies = [
strategy_combinations.tpu_strategy,
]

View File

@ -110,7 +110,8 @@ class TestDistributionStrategyDnnCorrectness(
self.run_correctness_test(distribution, use_numpy, use_validation_data)
@ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies() +
keras_correctness_test_base
.test_combinations_with_tpu_strategies_graph() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
use_numpy,
@ -309,7 +310,7 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
self.run_dynamic_lr_test(distribution)
@ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies())
keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
use_numpy,
use_validation_data):

View File

@ -20,11 +20,11 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.eager import context
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.platform import test
@testing_utils.run_all_without_tensor_float_32(
@ -97,12 +97,14 @@ class DistributionStrategyCnnCorrectnessTest(
return x_train, y_train, x_eval, y_eval, x_eval
@ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations())
keras_correctness_test_base.all_strategy_and_input_config_combinations() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_cnn_correctness(self, distribution, use_numpy, use_validation_data):
self.run_correctness_test(distribution, use_numpy, use_validation_data)
@ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations())
keras_correctness_test_base.all_strategy_and_input_config_combinations() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy,
use_validation_data):
self.run_correctness_test(
@ -112,7 +114,8 @@ class DistributionStrategyCnnCorrectnessTest(
with_batch_norm='regular')
@ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations())
keras_correctness_test_base.all_strategy_and_input_config_combinations() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy,
use_validation_data):
if not context.executing_eagerly():
@ -125,9 +128,10 @@ class DistributionStrategyCnnCorrectnessTest(
with_batch_norm='sync')
@ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies() +
keras_correctness_test_base
.strategy_minus_tpu_and_input_config_combinations_eager())
.all_strategy_and_input_config_combinations_eager() +
keras_correctness_test_base.multi_worker_mirrored_eager() +
keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
def test_cnn_correctness_with_partial_last_batch_eval(self, distribution,
use_numpy,
use_validation_data):
@ -139,9 +143,10 @@ class DistributionStrategyCnnCorrectnessTest(
training_epochs=1)
@ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies() +
keras_correctness_test_base
.strategy_minus_tpu_and_input_config_combinations_eager())
keras_correctness_test_base.
all_strategy_and_input_config_combinations_eager() +
keras_correctness_test_base.multi_worker_mirrored_eager() +
keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
self, distribution, use_numpy, use_validation_data):
self.run_correctness_test(
@ -153,4 +158,4 @@ class DistributionStrategyCnnCorrectnessTest(
if __name__ == '__main__':
test.main()
multi_process_runner.test_main()

View File

@ -93,7 +93,8 @@ class DistributionStrategyStatefulLstmModelCorrectnessTest(
@ds_combinations.generate(
combinations.times(
keras_correctness_test_base.test_combinations_with_tpu_strategies()))
keras_correctness_test_base
.test_combinations_with_tpu_strategies_graph()))
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
self, distribution, use_numpy, use_validation_data):
with self.assertRaisesRegex(