Remove usage of multi_process_runner_util.try_run_and_except_connection_error(), since setting GRPC_FAIL_FAST=false fixes the connection-error issue. Make grpc_fail_fast an argument of MultiProcessRunner's __init__().

PiperOrigin-RevId: 281193957
Change-Id: I1c99cb90a15fdb26892d0ad37c533c186297b09d
commit 2adcf83790
parent 06f9ec34c2
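For reference, a minimal usage sketch (not part of the commit) of the new argument, based on the run() signature introduced below; worker_fn is a hypothetical placeholder for a test's per-worker function:

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base as test_base


def worker_fn():
  # Hypothetical placeholder: a real test would exercise a tf.distribute
  # strategy here.
  pass


# grpc_fail_fast defaults to False, so GRPC retries connections instead of
# failing fast, and the try_run_and_except_connection_error() wrapper is no
# longer needed around this call.
multi_process_runner.run(
    worker_fn,
    cluster_spec=test_base.create_cluster_spec(num_workers=2),
    grpc_fail_fast=False)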
@@ -1323,7 +1323,6 @@ py_library(
     srcs = ["multi_process_runner.py"],
     deps = [
         ":multi_process_lib",
-        ":multi_process_runner_util",
         ":multi_worker_test_base",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:tf2",
@@ -1332,12 +1331,6 @@ py_library(
     ],
 )
 
-py_library(
-    name = "multi_process_runner_util",
-    srcs = ["multi_process_runner_util.py"],
-    deps = [],
-)
-
 py_library(
     name = "multi_process_lib",
     srcs = ["multi_process_lib.py"],
@@ -90,6 +90,7 @@ class MultiProcessRunner(object):
                cluster_spec,
                max_run_time=None,
                capture_std_stream=False,
+               grpc_fail_fast=False,
                args=None,
                kwargs=None):
     """Creates a multi-process runner.
@@ -111,6 +112,8 @@ class MultiProcessRunner(object):
         level C/C++ code. So it can be delayed for arbitrarily long time.
       capture_std_stream: Boolean, whether the messages streamed to stdout and
         stderr in subprocesses are captured.
+      grpc_fail_fast: Whether GRPC connection between processes should fail
+        without retrying. Defaults to False.
       args: Positional arguments to be sent to functions run on processes.
       kwargs: Keyword arguments to be sent to functions run on processes.
 
@@ -131,6 +134,7 @@ class MultiProcessRunner(object):
     self._cluster_spec = cluster_spec
     self._max_run_time = max_run_time
     self._capture_std_stream = capture_std_stream
+    self._grpc_fail_fast = grpc_fail_fast
     self._args = args or ()
     self._kwargs = kwargs or {}
     self._processes = []
@@ -164,6 +168,7 @@ class MultiProcessRunner(object):
 
   def _proc_func_wrapper(self, task_type, task_id, *arg, **kwargs):
     """The wrapper function that actually gets run in child process(es)."""
+    os.environ['GRPC_FAIL_FAST'] = str(self._grpc_fail_fast)
     os.environ['TF_CONFIG'] = json.dumps({
         'cluster': self._cluster_spec,
         'task': {
@@ -331,6 +336,7 @@ def run(proc_func,
         cluster_spec,
         max_run_time=None,
         capture_std_stream=False,
+        grpc_fail_fast=False,
         args=None,
         kwargs=None):  # pylint: disable=g-doc-args
   """Runs functions in local child processes.
@@ -347,6 +353,7 @@ def run(proc_func,
       cluster_spec,
       max_run_time=max_run_time,
       capture_std_stream=capture_std_stream,
+      grpc_fail_fast=grpc_fail_fast,
       args=args,
       kwargs=kwargs)
   runner.start()
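As a rough sketch of what the wrapper above now does in each child process (simplified, with assumed helper and TF_CONFIG field names; the actual logic lives in MultiProcessRunner._proc_func_wrapper):

import json
import os


def _configure_child_env(grpc_fail_fast, cluster_spec, task_type, task_id):
  """Sketch of the environment a child process sees before running the test fn."""
  # GRPC_FAIL_FAST controls whether GRPC channels fail immediately or keep
  # retrying; 'False' avoids the flaky connection errors that the removed
  # try_run_and_except_connection_error() helper used to skip over.
  os.environ['GRPC_FAIL_FAST'] = str(grpc_fail_fast)
  # TF_CONFIG tells the TensorFlow runtime its cluster layout and its own task.
  os.environ['TF_CONFIG'] = json.dumps({
      'cluster': cluster_spec,
      'task': {'type': task_type, 'index': task_id},
  })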
@@ -1,39 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Util for multi-process runner."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import contextlib
-
-from tensorflow.python.framework import errors_impl
-
-
-@contextlib.contextmanager
-def try_run_and_except_connection_error(test_obj):
-  """Context manager to skip cases not considered failures by the tests."""
-  # TODO(b/142074107): Remove this try-except once within-loop fault-tolerance
-  # is supported. This is temporarily needed to avoid test flakiness.
-  try:
-    yield
-  except errors_impl.UnavailableError as e:
-    if ('Connection reset by peer' in str(e) or 'Socket closed' in str(e) or
-        'failed to connect to all addresses' in str(e)):
-      test_obj.skipTest(
-          'Skipping connection error between processes: {}'.format(str(e)))
-    else:
-      raise
@@ -26,7 +26,6 @@ import numpy as np
 from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute import multi_process_runner_util
 from tensorflow.python.distribute import multi_worker_test_base as test_base
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.eager import context
@@ -76,11 +75,9 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
       for _ in range(100):
         worker_step_fn()
 
-    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
-    with multi_process_runner_util.try_run_and_except_connection_error(self):
-      multi_process_runner.run(
-          worker_fn,
-          cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))
+    multi_process_runner.run(
+        worker_fn,
+        cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))
 
 
 if __name__ == '__main__':
@@ -24,7 +24,6 @@ from absl.testing import parameterized
 from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute import multi_process_runner_util
 from tensorflow.python.distribute import multi_worker_test_base as test_base
 from tensorflow.python.keras import callbacks
 from tensorflow.python.keras.distribute import multi_worker_testing_utils
@@ -74,6 +73,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
 
     def proc_model_checkpoint_saves_on_chief_but_not_otherwise(
         test_obj, file_format):
+
       model, saving_filepath, train_ds, steps = _model_setup(
           test_obj, file_format)
       num_epoch = 2
@@ -104,12 +104,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
           training_state.checkpoint_exists(saving_filepath),
           test_base.is_chief())
 
-    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
-    with multi_process_runner_util.try_run_and_except_connection_error(self):
-      multi_process_runner.run(
-          proc_model_checkpoint_saves_on_chief_but_not_otherwise,
-          cluster_spec=test_base.create_cluster_spec(num_workers=2),
-          args=(self, file_format))
+    multi_process_runner.run(
+        proc_model_checkpoint_saves_on_chief_but_not_otherwise,
+        cluster_spec=test_base.create_cluster_spec(num_workers=2),
+        args=(self, file_format))
 
   @combinations.generate(combinations.combine(mode=['eager']))
   def test_tensorboard_saves_on_chief_but_not_otherwise(self, mode):
@@ -142,12 +140,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
       test_obj.assertEqual(
          bool(file_io.list_directory(saving_filepath)), test_base.is_chief())
 
-    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
-    with multi_process_runner_util.try_run_and_except_connection_error(self):
-      multi_process_runner.run(
-          proc_tensorboard_saves_on_chief_but_not_otherwise,
-          cluster_spec=test_base.create_cluster_spec(num_workers=2),
-          args=(self,))
+    multi_process_runner.run(
+        proc_tensorboard_saves_on_chief_but_not_otherwise,
+        cluster_spec=test_base.create_cluster_spec(num_workers=2),
+        args=(self,))
 
   @combinations.generate(combinations.combine(mode=['eager']))
   def test_tensorboard_can_still_save_to_temp_even_if_it_exists(self, mode):
@@ -173,12 +169,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
           steps_per_epoch=steps,
           callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
 
-    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
-    with multi_process_runner_util.try_run_and_except_connection_error(self):
-      multi_process_runner.run(
-          proc_tensorboard_can_still_save_to_temp_even_if_it_exists,
-          cluster_spec=test_base.create_cluster_spec(num_workers=2),
-          args=(self,))
+    multi_process_runner.run(
+        proc_tensorboard_can_still_save_to_temp_even_if_it_exists,
+        cluster_spec=test_base.create_cluster_spec(num_workers=2),
+        args=(self,))
 
 
 if __name__ == '__main__':