Add multi worker mirrored to keras_dnn_correctness_test
model.predict() doesn't support MWMS yet, so we skip predict() for MWMS. PiperOrigin-RevId: 332313925 Change-Id: I9e09e1bce496835e9aca74654af89cb54c2fb577
This commit is contained in:
parent
569d09e087
commit
7cfe00cf3a
@ -381,9 +381,11 @@ py_library(
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/distribute:collective_all_reduce_strategy",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:distribute_lib",
|
||||
"//tensorflow/python/distribute:mirrored_strategy",
|
||||
"//tensorflow/python/distribute:multi_process_runner",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
"//tensorflow/python/distribute:tpu_strategy",
|
||||
"//tensorflow/python/eager:context",
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
import six
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
@ -54,6 +55,14 @@ all_strategies = [
|
||||
]
|
||||
|
||||
|
||||
# TODO(b/159831559): add to all_strategies once all tests pass.
|
||||
multi_worker_mirrored = [
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
]
|
||||
|
||||
|
||||
def eager_mode_test_configuration():
|
||||
return combinations.combine(
|
||||
mode='eager', use_numpy=[True, False], use_validation_data=[True, False])
|
||||
@ -117,6 +126,18 @@ def test_combinations_with_tpu_strategies():
|
||||
graph_mode_test_configuration()))
|
||||
|
||||
|
||||
def multi_worker_mirrored_eager():
|
||||
return combinations.times(
|
||||
combinations.combine(distribution=multi_worker_mirrored),
|
||||
eager_mode_test_configuration())
|
||||
|
||||
|
||||
def multi_worker_mirrored_eager_and_graph():
|
||||
return combinations.times(
|
||||
combinations.combine(distribution=multi_worker_mirrored),
|
||||
eager_mode_test_configuration() + graph_mode_test_configuration())
|
||||
|
||||
|
||||
class MaybeDistributionScope(object):
|
||||
"""Provides a context allowing no distribution strategy."""
|
||||
|
||||
@ -263,7 +284,12 @@ def fit_eval_and_predict(initial_weights,
|
||||
|
||||
result['weights_1'] = model.get_weights()
|
||||
|
||||
if predict_inputs is not None:
|
||||
# TODO(b/157924053): Now model.predict() doesn't support
|
||||
# MultiWorkerMirroredStrategy. Enable model.predict() after it's supported.
|
||||
if predict_inputs is not None and not isinstance(
|
||||
distribution,
|
||||
(collective_all_reduce_strategy.CollectiveAllReduceStrategy,
|
||||
collective_all_reduce_strategy.CollectiveAllReduceStrategyV1)):
|
||||
# Check correctness of the result of predict() invoked
|
||||
# multiple times -- as for stateful models, result of
|
||||
# predict may differ for each batch.
|
||||
|
@ -22,20 +22,22 @@ from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import combinations as ds_combinations
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import test_combinations as combinations
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.distribute import keras_correctness_test_base
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
|
||||
def all_strategy_combinations_with_eager_and_graph_modes():
|
||||
return (combinations.combine(
|
||||
distribution=keras_correctness_test_base.all_strategies,
|
||||
mode=['graph', 'eager']))
|
||||
mode=['graph', 'eager']) + combinations.combine(
|
||||
distribution=keras_correctness_test_base.multi_worker_mirrored,
|
||||
mode='eager'))
|
||||
|
||||
|
||||
def all_strategy_combinations_with_graph_mode():
|
||||
@ -102,12 +104,14 @@ class TestDistributionStrategyDnnCorrectness(
|
||||
return x_train, y_train, x_eval, y_eval, x_predict
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations())
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_dnn_correctness(self, distribution, use_numpy, use_validation_data):
|
||||
self.run_correctness_test(distribution, use_numpy, use_validation_data)
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.test_combinations_with_tpu_strategies())
|
||||
keras_correctness_test_base.test_combinations_with_tpu_strategies() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
|
||||
use_numpy,
|
||||
use_validation_data):
|
||||
@ -116,7 +120,8 @@ class TestDistributionStrategyDnnCorrectness(
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base
|
||||
.strategy_minus_tpu_and_input_config_combinations_eager())
|
||||
.strategy_minus_tpu_and_input_config_combinations_eager() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_dnn_correctness_with_partial_last_batch(self, distribution,
|
||||
use_numpy,
|
||||
use_validation_data):
|
||||
@ -265,7 +270,8 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
|
||||
return model
|
||||
|
||||
@ds_combinations.generate(
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations())
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations() +
|
||||
keras_correctness_test_base.multi_worker_mirrored_eager())
|
||||
def test_dnn_correctness(self, distribution, use_numpy, use_validation_data):
|
||||
if (context.executing_eagerly()) or is_default_strategy(distribution):
|
||||
self.run_correctness_test(distribution, use_numpy, use_validation_data)
|
||||
@ -319,4 +325,4 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
multi_process_runner.test_main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user