From 2adcf837907c473d2666bb57f66bd4bbae1456c5 Mon Sep 17 00:00:00 2001 From: Rick Chao Date: Mon, 18 Nov 2019 17:47:09 -0800 Subject: [PATCH] Remove usage of multi_process_runner_util.try_run_and_except_connection_error() because setting GRPC_FAIL_FAST=false fixes the connection error issue. Make grpc_fail_fast an arg for MultiProcessRunner's __init__(). PiperOrigin-RevId: 281193957 Change-Id: I1c99cb90a15fdb26892d0ad37c533c186297b09d --- tensorflow/python/distribute/BUILD | 7 ---- .../python/distribute/multi_process_runner.py | 7 ++++ .../distribute/multi_process_runner_util.py | 39 ------------------- .../multi_worker_continuous_run_test.py | 9 ++--- .../multi_worker_callback_tf2_test.py | 32 +++++++-------- 5 files changed, 23 insertions(+), 71 deletions(-) delete mode 100644 tensorflow/python/distribute/multi_process_runner_util.py diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index fbe9ca3fa2e..e5bb8ace6e3 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1323,7 +1323,6 @@ py_library( srcs = ["multi_process_runner.py"], deps = [ ":multi_process_lib", - ":multi_process_runner_util", ":multi_worker_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:tf2", @@ -1332,12 +1331,6 @@ py_library( ], ) -py_library( - name = "multi_process_runner_util", - srcs = ["multi_process_runner_util.py"], - deps = [], -) - py_library( name = "multi_process_lib", srcs = ["multi_process_lib.py"], diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index c57741eadd1..d5da5811eac 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -90,6 +90,7 @@ class MultiProcessRunner(object): cluster_spec, max_run_time=None, capture_std_stream=False, + grpc_fail_fast=False, args=None, kwargs=None): """Creates a multi-process runner. 
@@ -111,6 +112,8 @@ class MultiProcessRunner(object): level C/C++ code. So it can be delayed for arbitrarily long time. capture_std_stream: Boolean, whether the messages streamed to stdout and stderr in subprocesses are captured. + grpc_fail_fast: Whether GRPC connection between processes should fail + without retrying. Defaults to False. args: Positional arguments to be sent to functions run on processes. kwargs: Keyword arguments to be sent to functions run on processes. @@ -131,6 +134,7 @@ class MultiProcessRunner(object): self._cluster_spec = cluster_spec self._max_run_time = max_run_time self._capture_std_stream = capture_std_stream + self._grpc_fail_fast = grpc_fail_fast self._args = args or () self._kwargs = kwargs or {} self._processes = [] @@ -164,6 +168,7 @@ class MultiProcessRunner(object): def _proc_func_wrapper(self, task_type, task_id, *arg, **kwargs): """The wrapper function that actually gets run in child process(es).""" + os.environ['GRPC_FAIL_FAST'] = str(self._grpc_fail_fast) os.environ['TF_CONFIG'] = json.dumps({ 'cluster': self._cluster_spec, 'task': { @@ -331,6 +336,7 @@ def run(proc_func, cluster_spec, max_run_time=None, capture_std_stream=False, + grpc_fail_fast=False, args=None, kwargs=None): # pylint: disable=g-doc-args """Runs functions in local child processes. @@ -347,6 +353,7 @@ def run(proc_func, cluster_spec, max_run_time=max_run_time, capture_std_stream=capture_std_stream, + grpc_fail_fast=grpc_fail_fast, args=args, kwargs=kwargs) runner.start() diff --git a/tensorflow/python/distribute/multi_process_runner_util.py b/tensorflow/python/distribute/multi_process_runner_util.py deleted file mode 100644 index 47e35713b1b..00000000000 --- a/tensorflow/python/distribute/multi_process_runner_util.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Util for multi-process runner.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import contextlib - -from tensorflow.python.framework import errors_impl - - -@contextlib.contextmanager -def try_run_and_except_connection_error(test_obj): - """Context manager to skip cases not considered failures by the tests.""" - # TODO(b/142074107): Remove this try-except once within-loop fault-tolerance - # is supported. This is temporarily needed to avoid test flakiness. 
- try: - yield - except errors_impl.UnavailableError as e: - if ('Connection reset by peer' in str(e) or 'Socket closed' in str(e) or - 'failed to connect to all addresses' in str(e)): - test_obj.skipTest( - 'Skipping connection error between processes: {}'.format(str(e))) - else: - raise diff --git a/tensorflow/python/distribute/multi_worker_continuous_run_test.py b/tensorflow/python/distribute/multi_worker_continuous_run_test.py index 9668bc23351..cca0ef91a5a 100644 --- a/tensorflow/python/distribute/multi_worker_continuous_run_test.py +++ b/tensorflow/python/distribute/multi_worker_continuous_run_test.py @@ -26,7 +26,6 @@ import numpy as np from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import combinations from tensorflow.python.distribute import multi_process_runner -from tensorflow.python.distribute import multi_process_runner_util from tensorflow.python.distribute import multi_worker_test_base as test_base from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context @@ -76,11 +75,9 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): for _ in range(100): worker_step_fn() - # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved. 
- with multi_process_runner_util.try_run_and_except_connection_error(self): - multi_process_runner.run( - worker_fn, - cluster_spec=test_base.create_cluster_spec(num_workers=num_workers)) + multi_process_runner.run( + worker_fn, + cluster_spec=test_base.create_cluster_spec(num_workers=num_workers)) if __name__ == '__main__': diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py index 91bd8c91dfc..a4f53f5958c 100644 --- a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py @@ -24,7 +24,6 @@ from absl.testing import parameterized from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy from tensorflow.python.distribute import combinations from tensorflow.python.distribute import multi_process_runner -from tensorflow.python.distribute import multi_process_runner_util from tensorflow.python.distribute import multi_worker_test_base as test_base from tensorflow.python.keras import callbacks from tensorflow.python.keras.distribute import multi_worker_testing_utils @@ -74,6 +73,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): def proc_model_checkpoint_saves_on_chief_but_not_otherwise( test_obj, file_format): + model, saving_filepath, train_ds, steps = _model_setup( test_obj, file_format) num_epoch = 2 @@ -104,12 +104,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): training_state.checkpoint_exists(saving_filepath), test_base.is_chief()) - # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved. 
- with multi_process_runner_util.try_run_and_except_connection_error(self): - multi_process_runner.run( - proc_model_checkpoint_saves_on_chief_but_not_otherwise, - cluster_spec=test_base.create_cluster_spec(num_workers=2), - args=(self, file_format)) + multi_process_runner.run( + proc_model_checkpoint_saves_on_chief_but_not_otherwise, + cluster_spec=test_base.create_cluster_spec(num_workers=2), + args=(self, file_format)) @combinations.generate(combinations.combine(mode=['eager'])) def test_tensorboard_saves_on_chief_but_not_otherwise(self, mode): @@ -142,12 +140,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): test_obj.assertEqual( bool(file_io.list_directory(saving_filepath)), test_base.is_chief()) - # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved. - with multi_process_runner_util.try_run_and_except_connection_error(self): - multi_process_runner.run( - proc_tensorboard_saves_on_chief_but_not_otherwise, - cluster_spec=test_base.create_cluster_spec(num_workers=2), - args=(self,)) + multi_process_runner.run( + proc_tensorboard_saves_on_chief_but_not_otherwise, + cluster_spec=test_base.create_cluster_spec(num_workers=2), + args=(self,)) @combinations.generate(combinations.combine(mode=['eager'])) def test_tensorboard_can_still_save_to_temp_even_if_it_exists(self, mode): @@ -173,12 +169,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): steps_per_epoch=steps, callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)]) - # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved. 
- with multi_process_runner_util.try_run_and_except_connection_error(self): - multi_process_runner.run( - proc_tensorboard_can_still_save_to_temp_even_if_it_exists, - cluster_spec=test_base.create_cluster_spec(num_workers=2), - args=(self,)) + multi_process_runner.run( + proc_tensorboard_can_still_save_to_temp_even_if_it_exists, + cluster_spec=test_base.create_cluster_spec(num_workers=2), + args=(self,)) if __name__ == '__main__':