From 7cfe00cf3af3bce10185a667e6b15fc90f5edb52 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Thu, 17 Sep 2020 14:16:16 -0700 Subject: [PATCH] Add multi worker mirrored to keras_dnn_correctness_test model.predict() doesn't support MWMS yet, so we skip predict() for MWMS. PiperOrigin-RevId: 332313925 Change-Id: I9e09e1bce496835e9aca74654af89cb54c2fb577 --- tensorflow/python/keras/distribute/BUILD | 2 ++ .../distribute/keras_correctness_test_base.py | 28 ++++++++++++++++++- .../distribute/keras_dnn_correctness_test.py | 20 ++++++++----- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 43b5c690096..f314013e30d 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -381,9 +381,11 @@ py_library( deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:training", + "//tensorflow/python/distribute:collective_all_reduce_strategy", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:multi_process_runner", "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/distribute:tpu_strategy", "//tensorflow/python/eager:context", diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py index 5af1e2806c0..825e94b9eba 100644 --- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py +++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py @@ -24,6 +24,7 @@ import numpy as np import six from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import strategy_combinations @@ -54,6 +55,14 @@ all_strategies = [ ] +# TODO(b/159831559): add to all_strategies once all tests pass. +multi_worker_mirrored = [ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, +] + + def eager_mode_test_configuration(): return combinations.combine( mode='eager', use_numpy=[True, False], use_validation_data=[True, False]) @@ -117,6 +126,18 @@ def test_combinations_with_tpu_strategies(): graph_mode_test_configuration())) +def multi_worker_mirrored_eager(): + return combinations.times( + combinations.combine(distribution=multi_worker_mirrored), + eager_mode_test_configuration()) + + +def multi_worker_mirrored_eager_and_graph(): + return combinations.times( + combinations.combine(distribution=multi_worker_mirrored), + eager_mode_test_configuration() + graph_mode_test_configuration()) + + class MaybeDistributionScope(object): """Provides a context allowing no distribution strategy.""" @@ -263,7 +284,12 @@ def fit_eval_and_predict(initial_weights, result['weights_1'] = model.get_weights() - if predict_inputs is not None: + # TODO(b/157924053): Now model.predict() doesn't support + # MultiWorkerMirroredStrategy. Enable model.predict() after it's supported. + if predict_inputs is not None and not isinstance( + distribution, + (collective_all_reduce_strategy.CollectiveAllReduceStrategy, + collective_all_reduce_strategy.CollectiveAllReduceStrategyV1)): # Check correctness of the result of predict() invoked # multiple times -- as for stateful models, result of # predict may differ for each batch. diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py index c8b40ab4601..621b7feadf7 100644 --- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py @@ -22,20 +22,22 @@ from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations as ds_combinations from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import multi_process_runner from tensorflow.python.eager import context from tensorflow.python.framework import test_combinations as combinations from tensorflow.python.keras import backend as K from tensorflow.python.keras import testing_utils from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras -from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent def all_strategy_combinations_with_eager_and_graph_modes(): return (combinations.combine( distribution=keras_correctness_test_base.all_strategies, - mode=['graph', 'eager'])) + mode=['graph', 'eager']) + combinations.combine( + distribution=keras_correctness_test_base.multi_worker_mirrored, + mode='eager')) def all_strategy_combinations_with_graph_mode(): @@ -102,12 +104,14 @@ class TestDistributionStrategyDnnCorrectness( return x_train, y_train, x_eval, y_eval, x_predict @ds_combinations.generate( - keras_correctness_test_base.all_strategy_and_input_config_combinations()) + keras_correctness_test_base.all_strategy_and_input_config_combinations() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_dnn_correctness(self, distribution, use_numpy, use_validation_data): self.run_correctness_test(distribution, use_numpy, use_validation_data) @ds_combinations.generate( - keras_correctness_test_base.test_combinations_with_tpu_strategies()) + keras_correctness_test_base.test_combinations_with_tpu_strategies() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_dnn_correctness_with_partial_last_batch_eval(self, distribution, use_numpy, use_validation_data): @@ -116,7 +120,8 @@ class TestDistributionStrategyDnnCorrectness( @ds_combinations.generate( keras_correctness_test_base - .strategy_minus_tpu_and_input_config_combinations_eager()) + .strategy_minus_tpu_and_input_config_combinations_eager() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_dnn_correctness_with_partial_last_batch(self, distribution, use_numpy, use_validation_data): @@ -265,7 +270,8 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel( return model @ds_combinations.generate( - keras_correctness_test_base.all_strategy_and_input_config_combinations()) + keras_correctness_test_base.all_strategy_and_input_config_combinations() + + keras_correctness_test_base.multi_worker_mirrored_eager()) def test_dnn_correctness(self, distribution, use_numpy, use_validation_data): if (context.executing_eagerly()) or is_default_strategy(distribution): self.run_correctness_test(distribution, use_numpy, use_validation_data) @@ -319,4 +325,4 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel( if __name__ == '__main__': - test.main() + multi_process_runner.test_main()