PSv2: Remove ParameterServerFailureError now that parameter server failures are coming back from runtime as UnavailableError.

PiperOrigin-RevId: 337933395
Change-Id: Ib6277de3b5d457e73d6efb099724f8d76e001045
This commit is contained in:
Rick Chao 2020-10-19 14:25:05 -07:00 committed by TensorFlower Gardener
parent 135cdcee4b
commit cca4ca7344
2 changed files with 8 additions and 27 deletions

View File

@@ -844,11 +844,6 @@ class Cluster(object):
return self._closure_queue.done()
class ParameterServerFailureError(Exception):
"""An error representing at least one parameter server is interrupted."""
pass
class Client(object):
"""An object to schedule and orchestrate remote function execution.
@@ -942,7 +937,7 @@ class Client(object):
"""
# Slot variables are usually created during function tracing time; thus
# `schedule` needs to be called within the `strategy.scope()`.
with self.strategy.scope(), _translate_parameter_server_failure():
with self.strategy.scope():
return self.cluster.schedule(fn, args=args, kwargs=kwargs)
def join(self):
@@ -964,7 +959,6 @@ class Client(object):
scheduled function since the last time an error was thrown or since
the beginning of the program.
"""
with _translate_parameter_server_failure():
self.cluster.join()
def done(self):
@@ -1064,25 +1058,12 @@ class Client(object):
return result
# pylint: disable=missing-function-docstring
@contextlib.contextmanager
def _translate_parameter_server_failure():
try:
yield
except Exception as e: # pylint: disable=broad-except
if _is_ps_failure(e):
raise ParameterServerFailureError(e)
else:
raise
# pylint: disable=missing-function-docstring
@contextlib.contextmanager
def handle_parameter_server_failure():
try:
with _translate_parameter_server_failure():
yield
except ParameterServerFailureError as e: # pylint: disable=broad-except
except errors.UnavailableError as e: # pylint: disable=broad-except
restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE",
None)
if restart_exit_code is not None:

View File

@@ -31,6 +31,7 @@ from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolve
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
@@ -72,8 +73,7 @@ class ClientMprTest(test.TestCase):
# Now the main process can terminate.
functions_scheduled_event.set()
# Verified that join and schedule indeed raise
# ParameterServerFailureError.
# Verified that join and schedule indeed raise UnavailableError.
try:
if test_join:
ps_client.join()
@@ -81,7 +81,7 @@ class ClientMprTest(test.TestCase):
while ps_client.cluster._closure_queue._error is None:
time.sleep(1)
ps_client.schedule(worker_fn)
except client_lib.ParameterServerFailureError:
except errors.UnavailableError:
# The following verifies that after PS fails, continue executing
# functions on workers should fail and indicate it's PS failure.
for worker_id in range(3):
@@ -101,7 +101,7 @@ class ClientMprTest(test.TestCase):
raise RuntimeError("Executing a function after PS fails, should "
"result in a PS failure.")
raise RuntimeError("ParameterServerFailureError supposed to be raised.")
raise RuntimeError("UnavailableError supposed to be raised.")
manager = multi_process_runner.manager()
functions_scheduled_event = manager.Event()