PSv2: Remove ParameterServerFailureError now that parameter server failures are coming back from runtime as UnavailableError.

PiperOrigin-RevId: 337933395
Change-Id: Ib6277de3b5d457e73d6efb099724f8d76e001045
This commit is contained in:
Rick Chao 2020-10-19 14:25:05 -07:00 committed by TensorFlower Gardener
parent 135cdcee4b
commit cca4ca7344
2 changed files with 8 additions and 27 deletions

View File

@ -844,11 +844,6 @@ class Cluster(object):
return self._closure_queue.done() return self._closure_queue.done()
class ParameterServerFailureError(Exception):
"""An error representing at least one parameter server is interrupted."""
pass
class Client(object): class Client(object):
"""An object to schedule and orchestrate remote function execution. """An object to schedule and orchestrate remote function execution.
@ -942,7 +937,7 @@ class Client(object):
""" """
# Slot variables are usually created during function tracing time; thus # Slot variables are usually created during function tracing time; thus
# `schedule` needs to be called within the `strategy.scope()`. # `schedule` needs to be called within the `strategy.scope()`.
with self.strategy.scope(), _translate_parameter_server_failure(): with self.strategy.scope():
return self.cluster.schedule(fn, args=args, kwargs=kwargs) return self.cluster.schedule(fn, args=args, kwargs=kwargs)
def join(self): def join(self):
@ -964,8 +959,7 @@ class Client(object):
scheduled function since the last time an error was thrown or since scheduled function since the last time an error was thrown or since
the beginning of the program. the beginning of the program.
""" """
with _translate_parameter_server_failure(): self.cluster.join()
self.cluster.join()
def done(self): def done(self):
"""Returns whether all the scheduled functions have finished execution. """Returns whether all the scheduled functions have finished execution.
@ -1064,25 +1058,12 @@ class Client(object):
return result return result
# pylint: disable=missing-function-docstring
@contextlib.contextmanager
def _translate_parameter_server_failure():
try:
yield
except Exception as e: # pylint: disable=broad-except
if _is_ps_failure(e):
raise ParameterServerFailureError(e)
else:
raise
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
@contextlib.contextmanager @contextlib.contextmanager
def handle_parameter_server_failure(): def handle_parameter_server_failure():
try: try:
with _translate_parameter_server_failure(): yield
yield except errors.UnavailableError as e: # pylint: disable=broad-except
except ParameterServerFailureError as e: # pylint: disable=broad-except
restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE", restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE",
None) None)
if restart_exit_code is not None: if restart_exit_code is not None:

View File

@ -31,6 +31,7 @@ from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolve
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -72,8 +73,7 @@ class ClientMprTest(test.TestCase):
# Now the main process can terminate. # Now the main process can terminate.
functions_scheduled_event.set() functions_scheduled_event.set()
# Verified that join and schedule indeed raise # Verified that join and schedule indeed raise UnavailableError.
# ParameterServerFailureError.
try: try:
if test_join: if test_join:
ps_client.join() ps_client.join()
@ -81,7 +81,7 @@ class ClientMprTest(test.TestCase):
while ps_client.cluster._closure_queue._error is None: while ps_client.cluster._closure_queue._error is None:
time.sleep(1) time.sleep(1)
ps_client.schedule(worker_fn) ps_client.schedule(worker_fn)
except client_lib.ParameterServerFailureError: except errors.UnavailableError:
# The following verifies that after PS fails, continue executing # The following verifies that after PS fails, continue executing
# functions on workers should fail and indicate it's PS failure. # functions on workers should fail and indicate it's PS failure.
for worker_id in range(3): for worker_id in range(3):
@ -101,7 +101,7 @@ class ClientMprTest(test.TestCase):
raise RuntimeError("Executing a function after PS fails, should " raise RuntimeError("Executing a function after PS fails, should "
"result in a PS failure.") "result in a PS failure.")
raise RuntimeError("ParameterServerFailureError supposed to be raised.") raise RuntimeError("UnavailableError supposed to be raised.")
manager = multi_process_runner.manager() manager = multi_process_runner.manager()
functions_scheduled_event = manager.Event() functions_scheduled_event = manager.Event()