From cca4ca73449b6b2ea698f13c67ab434d9b1b130b Mon Sep 17 00:00:00 2001 From: Rick Chao Date: Mon, 19 Oct 2020 14:25:05 -0700 Subject: [PATCH] PSv2: Remove ParameterServerFailureError now that parameter server failures are coming back from runtime as UnavailableError. PiperOrigin-RevId: 337933395 Change-Id: Ib6277de3b5d457e73d6efb099724f8d76e001045 --- tensorflow/python/distribute/client/client.py | 27 +++---------------- .../distribute/client/client_mpr_test.py | 8 +++--- 2 files changed, 8 insertions(+), 27 deletions(-) diff --git a/tensorflow/python/distribute/client/client.py b/tensorflow/python/distribute/client/client.py index 6eabbfa219a..be7157c1fea 100644 --- a/tensorflow/python/distribute/client/client.py +++ b/tensorflow/python/distribute/client/client.py @@ -844,11 +844,6 @@ class Cluster(object): return self._closure_queue.done() -class ParameterServerFailureError(Exception): - """An error representing at least one parameter server is interrupted.""" - pass - - class Client(object): """An object to schedule and orchestrate remote function execution. @@ -942,7 +937,7 @@ class Client(object): """ # Slot variables are usually created during function tracing time; thus # `schedule` needs to be called within the `strategy.scope()`. - with self.strategy.scope(), _translate_parameter_server_failure(): + with self.strategy.scope(): return self.cluster.schedule(fn, args=args, kwargs=kwargs) def join(self): @@ -964,8 +959,7 @@ class Client(object): scheduled function since the last time an error was thrown or since the beginning of the program. """ - with _translate_parameter_server_failure(): - self.cluster.join() + self.cluster.join() def done(self): """Returns whether all the scheduled functions have finished execution. @@ -1064,25 +1058,12 @@ class Client(object): return result -# pylint: disable=missing-function-docstring -@contextlib.contextmanager -def _translate_parameter_server_failure(): - try: - yield - except Exception as e: # pylint: disable=broad-except - if _is_ps_failure(e): - raise ParameterServerFailureError(e) - else: - raise - - # pylint: disable=missing-function-docstring @contextlib.contextmanager def handle_parameter_server_failure(): try: - with _translate_parameter_server_failure(): - yield - except ParameterServerFailureError as e: # pylint: disable=broad-except + yield + except errors.UnavailableError as e: # pylint: disable=broad-except restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE", None) if restart_exit_code is not None: diff --git a/tensorflow/python/distribute/client/client_mpr_test.py b/tensorflow/python/distribute/client/client_mpr_test.py index 802b23e87ec..7f66562b61c 100644 --- a/tensorflow/python/distribute/client/client_mpr_test.py +++ b/tensorflow/python/distribute/client/client_mpr_test.py @@ -31,6 +31,7 @@ from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolve from tensorflow.python.eager import def_function from tensorflow.python.eager import test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables @@ -72,8 +73,7 @@ class ClientMprTest(test.TestCase): # Now the main process can terminate. functions_scheduled_event.set() - # Verified that join and schedule indeed raise - # ParameterServerFailureError. + # Verified that join and schedule indeed raise UnavailableError. try: if test_join: ps_client.join() @@ -81,7 +81,7 @@ class ClientMprTest(test.TestCase): while ps_client.cluster._closure_queue._error is None: time.sleep(1) ps_client.schedule(worker_fn) - except client_lib.ParameterServerFailureError: + except errors.UnavailableError: # The following verifies that after PS fails, continue executing # functions on workers should fail and indicate it's PS failure. for worker_id in range(3): @@ -101,7 +101,7 @@ class ClientMprTest(test.TestCase): raise RuntimeError("Executing a function after PS fails, should " "result in a PS failure.") - raise RuntimeError("ParameterServerFailureError supposed to be raised.") + raise RuntimeError("UnavailableError supposed to be raised.") manager = multi_process_runner.manager() functions_scheduled_event = manager.Event()