Remove usage of multi_process_runner_util.try_run_and_except_connection_error(), as setting GRPC_FAIL_FAST=false fixes the connection error issue. Make grpc_fail_fast an arg of MultiProcessRunner's __init__().
PiperOrigin-RevId: 281193957
Change-Id: I1c99cb90a15fdb26892d0ad37c533c186297b09d
Parent: 06f9ec34c2
Commit: 2adcf83790
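For context, the sketch below (not part of the diff) shows how a multi-worker test can drive the runner once this change is in: grpc_fail_fast is forwarded from run() to MultiProcessRunner.__init__() and exported as the GRPC_FAIL_FAST environment variable inside each child process. The worker body and worker count here are placeholders for illustration only.

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base as test_base


def worker_fn():
  # Hypothetical per-worker body; real tests build a collective strategy and
  # run training steps here.
  pass


# grpc_fail_fast defaults to False, so each child process runs with
# GRPC_FAIL_FAST=False and GRPC retries transient connection errors between
# workers instead of failing immediately.
multi_process_runner.run(
    worker_fn,
    cluster_spec=test_base.create_cluster_spec(num_workers=2),
    grpc_fail_fast=False)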
@@ -1323,7 +1323,6 @@ py_library(
     srcs = ["multi_process_runner.py"],
     deps = [
         ":multi_process_lib",
-        ":multi_process_runner_util",
         ":multi_worker_test_base",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:tf2",
@@ -1332,12 +1331,6 @@ py_library(
     ],
 )
 
-py_library(
-    name = "multi_process_runner_util",
-    srcs = ["multi_process_runner_util.py"],
-    deps = [],
-)
-
 py_library(
     name = "multi_process_lib",
     srcs = ["multi_process_lib.py"],
@@ -90,6 +90,7 @@ class MultiProcessRunner(object):
                cluster_spec,
                max_run_time=None,
                capture_std_stream=False,
+               grpc_fail_fast=False,
                args=None,
                kwargs=None):
     """Creates a multi-process runner.
@@ -111,6 +112,8 @@ class MultiProcessRunner(object):
         level C/C++ code. So it can be delayed for arbitrarily long time.
       capture_std_stream: Boolean, whether the messages streamed to stdout and
         stderr in subprocesses are captured.
+      grpc_fail_fast: Whether GRPC connection between processes should fail
+        without retrying. Defaults to False.
       args: Positional arguments to be sent to functions run on processes.
       kwargs: Keyword arguments to be sent to functions run on processes.
 
@@ -131,6 +134,7 @@ class MultiProcessRunner(object):
     self._cluster_spec = cluster_spec
     self._max_run_time = max_run_time
     self._capture_std_stream = capture_std_stream
+    self._grpc_fail_fast = grpc_fail_fast
     self._args = args or ()
     self._kwargs = kwargs or {}
     self._processes = []
@@ -164,6 +168,7 @@ class MultiProcessRunner(object):
 
   def _proc_func_wrapper(self, task_type, task_id, *arg, **kwargs):
     """The wrapper function that actually gets run in child process(es)."""
+    os.environ['GRPC_FAIL_FAST'] = str(self._grpc_fail_fast)
     os.environ['TF_CONFIG'] = json.dumps({
         'cluster': self._cluster_spec,
         'task': {
@@ -331,6 +336,7 @@ def run(proc_func,
         cluster_spec,
         max_run_time=None,
         capture_std_stream=False,
+        grpc_fail_fast=False,
         args=None,
         kwargs=None):  # pylint: disable=g-doc-args
   """Runs functions in local child processes.
@@ -347,6 +353,7 @@ def run(proc_func,
       cluster_spec,
       max_run_time=max_run_time,
       capture_std_stream=capture_std_stream,
+      grpc_fail_fast=grpc_fail_fast,
       args=args,
       kwargs=kwargs)
   runner.start()
@@ -1,39 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Util for multi-process runner."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import contextlib
-
-from tensorflow.python.framework import errors_impl
-
-
-@contextlib.contextmanager
-def try_run_and_except_connection_error(test_obj):
-  """Context manager to skip cases not considered failures by the tests."""
-  # TODO(b/142074107): Remove this try-except once within-loop fault-tolerance
-  # is supported. This is temporarily needed to avoid test flakiness.
-  try:
-    yield
-  except errors_impl.UnavailableError as e:
-    if ('Connection reset by peer' in str(e) or 'Socket closed' in str(e) or
-        'failed to connect to all addresses' in str(e)):
-      test_obj.skipTest(
-          'Skipping connection error between processes: {}'.format(str(e)))
-    else:
-      raise
@@ -26,7 +26,6 @@ import numpy as np
 from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute import multi_process_runner_util
 from tensorflow.python.distribute import multi_worker_test_base as test_base
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.eager import context
@@ -76,11 +75,9 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase):
       for _ in range(100):
         worker_step_fn()
 
-    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
-    with multi_process_runner_util.try_run_and_except_connection_error(self):
-      multi_process_runner.run(
-          worker_fn,
-          cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))
+    multi_process_runner.run(
+        worker_fn,
+        cluster_spec=test_base.create_cluster_spec(num_workers=num_workers))
 
 
 if __name__ == '__main__':
@@ -24,7 +24,6 @@ from absl.testing import parameterized
 from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute import multi_process_runner_util
 from tensorflow.python.distribute import multi_worker_test_base as test_base
 from tensorflow.python.keras import callbacks
 from tensorflow.python.keras.distribute import multi_worker_testing_utils
@@ -74,6 +73,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
 
   def proc_model_checkpoint_saves_on_chief_but_not_otherwise(
       test_obj, file_format):
+
     model, saving_filepath, train_ds, steps = _model_setup(
         test_obj, file_format)
     num_epoch = 2
@@ -104,12 +104,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
         training_state.checkpoint_exists(saving_filepath),
         test_base.is_chief())
 
-    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
-    with multi_process_runner_util.try_run_and_except_connection_error(self):
-      multi_process_runner.run(
-          proc_model_checkpoint_saves_on_chief_but_not_otherwise,
-          cluster_spec=test_base.create_cluster_spec(num_workers=2),
-          args=(self, file_format))
+    multi_process_runner.run(
+        proc_model_checkpoint_saves_on_chief_but_not_otherwise,
+        cluster_spec=test_base.create_cluster_spec(num_workers=2),
+        args=(self, file_format))
 
   @combinations.generate(combinations.combine(mode=['eager']))
   def test_tensorboard_saves_on_chief_but_not_otherwise(self, mode):
@@ -142,12 +140,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
     test_obj.assertEqual(
         bool(file_io.list_directory(saving_filepath)), test_base.is_chief())
 
-    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
-    with multi_process_runner_util.try_run_and_except_connection_error(self):
-      multi_process_runner.run(
-          proc_tensorboard_saves_on_chief_but_not_otherwise,
-          cluster_spec=test_base.create_cluster_spec(num_workers=2),
-          args=(self,))
+    multi_process_runner.run(
+        proc_tensorboard_saves_on_chief_but_not_otherwise,
+        cluster_spec=test_base.create_cluster_spec(num_workers=2),
+        args=(self,))
 
   @combinations.generate(combinations.combine(mode=['eager']))
   def test_tensorboard_can_still_save_to_temp_even_if_it_exists(self, mode):
@@ -173,12 +169,10 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
         steps_per_epoch=steps,
         callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
 
-    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
-    with multi_process_runner_util.try_run_and_except_connection_error(self):
-      multi_process_runner.run(
-          proc_tensorboard_can_still_save_to_temp_even_if_it_exists,
-          cluster_spec=test_base.create_cluster_spec(num_workers=2),
-          args=(self,))
+    multi_process_runner.run(
+        proc_tensorboard_can_still_save_to_temp_even_if_it_exists,
+        cluster_spec=test_base.create_cluster_spec(num_workers=2),
+        args=(self,))
 
 
 if __name__ == '__main__':