Add multi worker mirrored to keras_dnn_correctness_test

model.predict() doesn't support MWMS yet, so we skip predict() for MWMS.

PiperOrigin-RevId: 332313925
Change-Id: I9e09e1bce496835e9aca74654af89cb54c2fb577
This commit is contained in:
Ran Chen 2020-09-17 14:16:16 -07:00 committed by TensorFlower Gardener
parent 569d09e087
commit 7cfe00cf3a
3 changed files with 42 additions and 8 deletions

View File

@ -381,9 +381,11 @@ py_library(
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/distribute:collective_all_reduce_strategy",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/distribute:multi_process_runner",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/distribute:tpu_strategy",
"//tensorflow/python/eager:context",

View File

@ -24,6 +24,7 @@ import numpy as np
import six
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations
@ -54,6 +55,14 @@ all_strategies = [
]
# TODO(b/159831559): add to all_strategies once all tests pass.
multi_worker_mirrored = [
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
]
def eager_mode_test_configuration():
return combinations.combine(
mode='eager', use_numpy=[True, False], use_validation_data=[True, False])
@ -117,6 +126,18 @@ def test_combinations_with_tpu_strategies():
graph_mode_test_configuration()))
def multi_worker_mirrored_eager():
return combinations.times(
combinations.combine(distribution=multi_worker_mirrored),
eager_mode_test_configuration())
def multi_worker_mirrored_eager_and_graph():
return combinations.times(
combinations.combine(distribution=multi_worker_mirrored),
eager_mode_test_configuration() + graph_mode_test_configuration())
class MaybeDistributionScope(object):
"""Provides a context allowing no distribution strategy."""
@ -263,7 +284,12 @@ def fit_eval_and_predict(initial_weights,
result['weights_1'] = model.get_weights()
if predict_inputs is not None:
# TODO(b/157924053): Now model.predict() doesn't support
# MultiWorkerMirroredStrategy. Enable model.predict() after it's supported.
if predict_inputs is not None and not isinstance(
distribution,
(collective_all_reduce_strategy.CollectiveAllReduceStrategy,
collective_all_reduce_strategy.CollectiveAllReduceStrategyV1)):
# Check correctness of the result of predict() invoked
# multiple times -- as for stateful models, result of
# predict may differ for each batch.

View File

@ -22,20 +22,22 @@ from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.eager import context
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
def all_strategy_combinations_with_eager_and_graph_modes():
return (combinations.combine(
distribution=keras_correctness_test_base.all_strategies,
mode=['graph', 'eager']))
mode=['graph', 'eager']) + combinations.combine(
distribution=keras_correctness_test_base.multi_worker_mirrored,
mode='eager'))
def all_strategy_combinations_with_graph_mode():
@ -102,12 +104,14 @@ class TestDistributionStrategyDnnCorrectness(
return x_train, y_train, x_eval, y_eval, x_predict
@ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations())
keras_correctness_test_base.all_strategy_and_input_config_combinations() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_dnn_correctness(self, distribution, use_numpy, use_validation_data):
self.run_correctness_test(distribution, use_numpy, use_validation_data)
@ds_combinations.generate(
keras_correctness_test_base.test_combinations_with_tpu_strategies())
keras_correctness_test_base.test_combinations_with_tpu_strategies() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
use_numpy,
use_validation_data):
@ -116,7 +120,8 @@ class TestDistributionStrategyDnnCorrectness(
@ds_combinations.generate(
keras_correctness_test_base
.strategy_minus_tpu_and_input_config_combinations_eager())
.strategy_minus_tpu_and_input_config_combinations_eager() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_dnn_correctness_with_partial_last_batch(self, distribution,
use_numpy,
use_validation_data):
@ -265,7 +270,8 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
return model
@ds_combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations())
keras_correctness_test_base.all_strategy_and_input_config_combinations() +
keras_correctness_test_base.multi_worker_mirrored_eager())
def test_dnn_correctness(self, distribution, use_numpy, use_validation_data):
if (context.executing_eagerly()) or is_default_strategy(distribution):
self.run_correctness_test(distribution, use_numpy, use_validation_data)
@ -319,4 +325,4 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
if __name__ == '__main__':
test.main()
multi_process_runner.test_main()