tf.distribute: Add MultiWorkerMirroredStrategy to keras_image_correctness_test.
While I am here, enable some of the partial batch tests with TPUStrategy + eager, those were previously only tested for graph mode. PiperOrigin-RevId: 332513819 Change-Id: I32a43191550b1ec978d86ad470fdfdd63d81083f
This commit is contained in:
parent
00302787b7
commit
19bb9fcff5
@ -74,11 +74,16 @@ def graph_mode_test_configuration():
|
||||
|
||||
def all_strategy_and_input_config_combinations():
|
||||
return (combinations.times(
|
||||
combinations.combine(
|
||||
distribution=all_strategies),
|
||||
combinations.combine(distribution=all_strategies),
|
||||
eager_mode_test_configuration() + graph_mode_test_configuration()))
|
||||
|
||||
|
||||
def all_strategy_and_input_config_combinations_eager():
|
||||
return (combinations.times(
|
||||
combinations.combine(distribution=all_strategies),
|
||||
eager_mode_test_configuration()))
|
||||
|
||||
|
||||
def strategy_minus_tpu_and_input_config_combinations_eager():
|
||||
return (combinations.times(
|
||||
combinations.combine(
|
||||
@ -114,7 +119,7 @@ def test_combinations_for_embedding_model():
|
||||
(eager_mode_test_configuration())))
|
||||
|
||||
|
||||
def test_combinations_with_tpu_strategies():
|
||||
def test_combinations_with_tpu_strategies_graph():
|
||||
tpu_strategies = [
|
||||
strategy_combinations.tpu_strategy,
|
||||
]
|
||||
|
@ -110,7 +110,8 @@ class TestDistributionStrategyDnnCorrectness(
|
||||
self.run_correctness_test(distribution, use_numpy, use_validation_data)
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.test_combinations_with_tpu_strategies() +
|
||||
keras_correctness_test_base
|
||||
.test_combinations_with_tpu_strategies_graph() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
|
||||
use_numpy,
|
||||
@ -309,7 +310,7 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
|
||||
self.run_dynamic_lr_test(distribution)
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.test_combinations_with_tpu_strategies())
|
||||
keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
|
||||
def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
|
||||
use_numpy,
|
||||
use_validation_data):
|
||||
|
@ -20,11 +20,11 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.distribute import combinations as ds_combinations
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.distribute import keras_correctness_test_base
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@testing_utils.run_all_without_tensor_float_32(
|
||||
@ -97,12 +97,14 @@ class DistributionStrategyCnnCorrectnessTest(
|
||||
return x_train, y_train, x_eval, y_eval, x_eval
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations())
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_cnn_correctness(self, distribution, use_numpy, use_validation_data):
|
||||
self.run_correctness_test(distribution, use_numpy, use_validation_data)
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations())
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy,
|
||||
use_validation_data):
|
||||
self.run_correctness_test(
|
||||
@ -112,7 +114,8 @@ class DistributionStrategyCnnCorrectnessTest(
|
||||
with_batch_norm='regular')
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations())
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy,
|
||||
use_validation_data):
|
||||
if not context.executing_eagerly():
|
||||
@ -125,9 +128,10 @@ class DistributionStrategyCnnCorrectnessTest(
|
||||
with_batch_norm='sync')
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.test_combinations_with_tpu_strategies() +
|
||||
keras_correctness_test_base
|
||||
.strategy_minus_tpu_and_input_config_combinations_eager())
|
||||
.all_strategy_and_input_config_combinations_eager() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager() +
|
||||
keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
|
||||
def test_cnn_correctness_with_partial_last_batch_eval(self, distribution,
|
||||
use_numpy,
|
||||
use_validation_data):
|
||||
@ -139,9 +143,10 @@ class DistributionStrategyCnnCorrectnessTest(
|
||||
training_epochs=1)
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.test_combinations_with_tpu_strategies() +
|
||||
keras_correctness_test_base
|
||||
.strategy_minus_tpu_and_input_config_combinations_eager())
|
||||
keras_correctness_test_base.
|
||||
all_strategy_and_input_config_combinations_eager() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager() +
|
||||
keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
|
||||
def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
|
||||
self, distribution, use_numpy, use_validation_data):
|
||||
self.run_correctness_test(
|
||||
@ -153,4 +158,4 @@ class DistributionStrategyCnnCorrectnessTest(
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
multi_process_runner.test_main()
|
||||
|
@ -93,7 +93,8 @@ class DistributionStrategyStatefulLstmModelCorrectnessTest(
|
||||
|
||||
@ds_combinations.generate(
|
||||
combinations.times(
|
||||
keras_correctness_test_base.test_combinations_with_tpu_strategies()))
|
||||
keras_correctness_test_base
|
||||
.test_combinations_with_tpu_strategies_graph()))
|
||||
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
|
||||
self, distribution, use_numpy, use_validation_data):
|
||||
with self.assertRaisesRegex(
|
||||
|
Loading…
Reference in New Issue
Block a user