diff --git a/tensorflow/python/distribute/client/client.py b/tensorflow/python/distribute/client/client.py index d69d06cc8dc..6eabbfa219a 100644 --- a/tensorflow/python/distribute/client/client.py +++ b/tensorflow/python/distribute/client/client.py @@ -163,13 +163,14 @@ class RemoteValue(object): The remote value, as a numpy data type (if scalar) or ndarray. Raises: - FunctionRetryableError: If the function that produces this `RemoteValue` + tf.errors.CancelledError: If the function that produces this `RemoteValue` is aborted or cancelled due to failure, and the user should handle and reschedule. """ self._status_available_event.wait() if self._status is _RemoteValueStatus.ABORTED: - raise FunctionRetryableError( + raise errors.CancelledError( + None, None, "The corresponding function is aborted. Please reschedule the " "function.") if self._error is not None: @@ -191,11 +192,6 @@ class InputError(Exception): super().__init__(message) -class FunctionRetryableError(Exception): - """An error that represents the closure was aborted and should be retried.""" - pass - - def _maybe_rebuild_remote_values(worker, structure): """Attempts to return errors from `RemoteValue`s. Rebuilds them if needed.""" errors_in_structure = [] @@ -327,9 +323,9 @@ class Closure(object): def _set_output_remote_values_cancelled(self): nest.map_structure( lambda x: x._set_error( # pylint: disable=protected-access,g-long-lambda - FunctionRetryableError("The corresponding function is " - "cancelled. Please reschedule the " - "function.")), + errors.CancelledError( + None, None, "The corresponding function is " + "cancelled. Please reschedule the function.")), self._output_remote_values) # pylint: disable=protected-access def execute_on(self, worker): diff --git a/tensorflow/python/distribute/client/client_test.py b/tensorflow/python/distribute/client/client_test.py index fb68716dc26..981ad964b6d 100644 --- a/tensorflow/python/distribute/client/client_test.py +++ b/tensorflow/python/distribute/client/client_test.py @@ -27,6 +27,7 @@ import time from tensorflow.python.distribute.client import client from tensorflow.python.eager import cancellation from tensorflow.python.eager import def_function +from tensorflow.python.framework import errors from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import coordinator @@ -203,7 +204,7 @@ class CoordinatedClosureQueueTest(test.TestCase): self.assertTrue(closure_queue.done()) with self.assertRaisesRegex( - client.FunctionRetryableError, + errors.CancelledError, 'The corresponding function is cancelled. Please reschedule the ' 'function.'): closure2._fetch_output_remote_values() @@ -225,7 +226,7 @@ class CoordinatedClosureQueueTest(test.TestCase): self.assertTrue(closure_queue.done()) with self.assertRaisesRegex( - client.FunctionRetryableError, + errors.CancelledError, 'The corresponding function is cancelled. Please reschedule the ' 'function.'): closure2._fetch_output_remote_values() @@ -307,7 +308,7 @@ class CoordinatedClosureQueueTest(test.TestCase): # The following asserts that closure3 should have been cancelled. if not call_wait: with self.assertRaisesRegex( - client.FunctionRetryableError, + errors.CancelledError, 'The corresponding function is cancelled. Please reschedule the ' 'function.'): closure3._fetch_output_remote_values() diff --git a/tensorflow/python/distribute/client/parameter_server_client_test.py b/tensorflow/python/distribute/client/parameter_server_client_test.py index 78bd5f76c61..022539308d1 100644 --- a/tensorflow/python/distribute/client/parameter_server_client_test.py +++ b/tensorflow/python/distribute/client/parameter_server_client_test.py @@ -406,7 +406,7 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread): with self.assertRaises(errors.InvalidArgumentError): self.client.join() - with self.assertRaises(client_lib.FunctionRetryableError): + with self.assertRaises(errors.CancelledError): long_function.fetch() for _ in range(3):