PSv2: Remove ParameterServerFailureError now that parameter server failures are coming back from runtime as UnavailableError.
PiperOrigin-RevId: 337933395 Change-Id: Ib6277de3b5d457e73d6efb099724f8d76e001045
This commit is contained in:
parent
135cdcee4b
commit
cca4ca7344
@ -844,11 +844,6 @@ class Cluster(object):
|
|||||||
return self._closure_queue.done()
|
return self._closure_queue.done()
|
||||||
|
|
||||||
|
|
||||||
class ParameterServerFailureError(Exception):
|
|
||||||
"""An error representing at least one parameter server is interrupted."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Client(object):
|
class Client(object):
|
||||||
"""An object to schedule and orchestrate remote function execution.
|
"""An object to schedule and orchestrate remote function execution.
|
||||||
|
|
||||||
@ -942,7 +937,7 @@ class Client(object):
|
|||||||
"""
|
"""
|
||||||
# Slot variables are usually created during function tracing time; thus
|
# Slot variables are usually created during function tracing time; thus
|
||||||
# `schedule` needs to be called within the `strategy.scope()`.
|
# `schedule` needs to be called within the `strategy.scope()`.
|
||||||
with self.strategy.scope(), _translate_parameter_server_failure():
|
with self.strategy.scope():
|
||||||
return self.cluster.schedule(fn, args=args, kwargs=kwargs)
|
return self.cluster.schedule(fn, args=args, kwargs=kwargs)
|
||||||
|
|
||||||
def join(self):
|
def join(self):
|
||||||
@ -964,8 +959,7 @@ class Client(object):
|
|||||||
scheduled function since the last time an error was thrown or since
|
scheduled function since the last time an error was thrown or since
|
||||||
the beginning of the program.
|
the beginning of the program.
|
||||||
"""
|
"""
|
||||||
with _translate_parameter_server_failure():
|
self.cluster.join()
|
||||||
self.cluster.join()
|
|
||||||
|
|
||||||
def done(self):
|
def done(self):
|
||||||
"""Returns whether all the scheduled functions have finished execution.
|
"""Returns whether all the scheduled functions have finished execution.
|
||||||
@ -1064,25 +1058,12 @@ class Client(object):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=missing-function-docstring
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def _translate_parameter_server_failure():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except Exception as e: # pylint: disable=broad-except
|
|
||||||
if _is_ps_failure(e):
|
|
||||||
raise ParameterServerFailureError(e)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=missing-function-docstring
|
# pylint: disable=missing-function-docstring
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def handle_parameter_server_failure():
|
def handle_parameter_server_failure():
|
||||||
try:
|
try:
|
||||||
with _translate_parameter_server_failure():
|
yield
|
||||||
yield
|
except errors.UnavailableError as e: # pylint: disable=broad-except
|
||||||
except ParameterServerFailureError as e: # pylint: disable=broad-except
|
|
||||||
restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE",
|
restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE",
|
||||||
None)
|
None)
|
||||||
if restart_exit_code is not None:
|
if restart_exit_code is not None:
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolve
|
|||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
@ -72,8 +73,7 @@ class ClientMprTest(test.TestCase):
|
|||||||
# Now the main process can terminate.
|
# Now the main process can terminate.
|
||||||
functions_scheduled_event.set()
|
functions_scheduled_event.set()
|
||||||
|
|
||||||
# Verified that join and schedule indeed raise
|
# Verified that join and schedule indeed raise UnavailableError.
|
||||||
# ParameterServerFailureError.
|
|
||||||
try:
|
try:
|
||||||
if test_join:
|
if test_join:
|
||||||
ps_client.join()
|
ps_client.join()
|
||||||
@ -81,7 +81,7 @@ class ClientMprTest(test.TestCase):
|
|||||||
while ps_client.cluster._closure_queue._error is None:
|
while ps_client.cluster._closure_queue._error is None:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
ps_client.schedule(worker_fn)
|
ps_client.schedule(worker_fn)
|
||||||
except client_lib.ParameterServerFailureError:
|
except errors.UnavailableError:
|
||||||
# The following verifies that after PS fails, continue executing
|
# The following verifies that after PS fails, continue executing
|
||||||
# functions on workers should fail and indicate it's PS failure.
|
# functions on workers should fail and indicate it's PS failure.
|
||||||
for worker_id in range(3):
|
for worker_id in range(3):
|
||||||
@ -101,7 +101,7 @@ class ClientMprTest(test.TestCase):
|
|||||||
raise RuntimeError("Executing a function after PS fails, should "
|
raise RuntimeError("Executing a function after PS fails, should "
|
||||||
"result in a PS failure.")
|
"result in a PS failure.")
|
||||||
|
|
||||||
raise RuntimeError("ParameterServerFailureError supposed to be raised.")
|
raise RuntimeError("UnavailableError supposed to be raised.")
|
||||||
|
|
||||||
manager = multi_process_runner.manager()
|
manager = multi_process_runner.manager()
|
||||||
functions_scheduled_event = manager.Event()
|
functions_scheduled_event = manager.Event()
|
||||||
|
Loading…
Reference in New Issue
Block a user