PSv2: Remove ParameterServerFailureError now that parameter server failures are coming back from runtime as UnavailableError.
PiperOrigin-RevId: 337933395 Change-Id: Ib6277de3b5d457e73d6efb099724f8d76e001045
This commit is contained in:
parent
135cdcee4b
commit
cca4ca7344
@ -844,11 +844,6 @@ class Cluster(object):
|
||||
return self._closure_queue.done()
|
||||
|
||||
|
||||
class ParameterServerFailureError(Exception):
|
||||
"""An error representing at least one parameter server is interrupted."""
|
||||
pass
|
||||
|
||||
|
||||
class Client(object):
|
||||
"""An object to schedule and orchestrate remote function execution.
|
||||
|
||||
@ -942,7 +937,7 @@ class Client(object):
|
||||
"""
|
||||
# Slot variables are usually created during function tracing time; thus
|
||||
# `schedule` needs to be called within the `strategy.scope()`.
|
||||
with self.strategy.scope(), _translate_parameter_server_failure():
|
||||
with self.strategy.scope():
|
||||
return self.cluster.schedule(fn, args=args, kwargs=kwargs)
|
||||
|
||||
def join(self):
|
||||
@ -964,7 +959,6 @@ class Client(object):
|
||||
scheduled function since the last time an error was thrown or since
|
||||
the beginning of the program.
|
||||
"""
|
||||
with _translate_parameter_server_failure():
|
||||
self.cluster.join()
|
||||
|
||||
def done(self):
|
||||
@ -1064,25 +1058,12 @@ class Client(object):
|
||||
return result
|
||||
|
||||
|
||||
# pylint: disable=missing-function-docstring
|
||||
@contextlib.contextmanager
|
||||
def _translate_parameter_server_failure():
|
||||
try:
|
||||
yield
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
if _is_ps_failure(e):
|
||||
raise ParameterServerFailureError(e)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
# pylint: disable=missing-function-docstring
|
||||
@contextlib.contextmanager
|
||||
def handle_parameter_server_failure():
|
||||
try:
|
||||
with _translate_parameter_server_failure():
|
||||
yield
|
||||
except ParameterServerFailureError as e: # pylint: disable=broad-except
|
||||
except errors.UnavailableError as e: # pylint: disable=broad-except
|
||||
restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE",
|
||||
None)
|
||||
if restart_exit_code is not None:
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolve
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -72,8 +73,7 @@ class ClientMprTest(test.TestCase):
|
||||
# Now the main process can terminate.
|
||||
functions_scheduled_event.set()
|
||||
|
||||
# Verified that join and schedule indeed raise
|
||||
# ParameterServerFailureError.
|
||||
# Verified that join and schedule indeed raise UnavailableError.
|
||||
try:
|
||||
if test_join:
|
||||
ps_client.join()
|
||||
@ -81,7 +81,7 @@ class ClientMprTest(test.TestCase):
|
||||
while ps_client.cluster._closure_queue._error is None:
|
||||
time.sleep(1)
|
||||
ps_client.schedule(worker_fn)
|
||||
except client_lib.ParameterServerFailureError:
|
||||
except errors.UnavailableError:
|
||||
# The following verifies that after PS fails, continue executing
|
||||
# functions on workers should fail and indicate it's PS failure.
|
||||
for worker_id in range(3):
|
||||
@ -101,7 +101,7 @@ class ClientMprTest(test.TestCase):
|
||||
raise RuntimeError("Executing a function after PS fails, should "
|
||||
"result in a PS failure.")
|
||||
|
||||
raise RuntimeError("ParameterServerFailureError supposed to be raised.")
|
||||
raise RuntimeError("UnavailableError supposed to be raised.")
|
||||
|
||||
manager = multi_process_runner.manager()
|
||||
functions_scheduled_event = manager.Event()
|
||||
|
Loading…
Reference in New Issue
Block a user