PSv2: Replace FunctionRetryableError with tf.errors.CancelledError as they

should deserve the same user treatment.

PiperOrigin-RevId: 335677674
Change-Id: I9824aaba7c40fafa9604950165903f9fcb072210
This commit is contained in:
Rick Chao 2020-10-06 11:05:22 -07:00 committed by TensorFlower Gardener
parent d7bb5785ce
commit 9340214eef
3 changed files with 11 additions and 14 deletions

View File

@ -163,13 +163,14 @@ class RemoteValue(object):
The remote value, as a numpy data type (if scalar) or ndarray.
Raises:
FunctionRetryableError: If the function that produces this `RemoteValue`
tf.errors.CancelledError: If the function that produces this `RemoteValue`
is aborted or cancelled due to failure, and the user should handle and
reschedule.
"""
self._status_available_event.wait()
if self._status is _RemoteValueStatus.ABORTED:
raise FunctionRetryableError(
raise errors.CancelledError(
None, None,
"The corresponding function is aborted. Please reschedule the "
"function.")
if self._error is not None:
@ -191,11 +192,6 @@ class InputError(Exception):
super().__init__(message)
class FunctionRetryableError(Exception):
"""An error that represents the closure was aborted and should be retried."""
pass
def _maybe_rebuild_remote_values(worker, structure):
"""Attempts to return errors from `RemoteValue`s. Rebuilds them if needed."""
errors_in_structure = []
@ -327,9 +323,9 @@ class Closure(object):
def _set_output_remote_values_cancelled(self):
nest.map_structure(
lambda x: x._set_error( # pylint: disable=protected-access,g-long-lambda
FunctionRetryableError("The corresponding function is "
"cancelled. Please reschedule the "
"function.")),
errors.CancelledError(
None, None, "The corresponding function is "
"cancelled. Please reschedule the function.")),
self._output_remote_values) # pylint: disable=protected-access
def execute_on(self, worker):

View File

@ -27,6 +27,7 @@ import time
from tensorflow.python.distribute.client import client
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import def_function
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
@ -203,7 +204,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
self.assertTrue(closure_queue.done())
with self.assertRaisesRegex(
client.FunctionRetryableError,
errors.CancelledError,
'The corresponding function is cancelled. Please reschedule the '
'function.'):
closure2._fetch_output_remote_values()
@ -225,7 +226,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
self.assertTrue(closure_queue.done())
with self.assertRaisesRegex(
client.FunctionRetryableError,
errors.CancelledError,
'The corresponding function is cancelled. Please reschedule the '
'function.'):
closure2._fetch_output_remote_values()
@ -307,7 +308,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
# The following asserts that closure3 should have been cancelled.
if not call_wait:
with self.assertRaisesRegex(
client.FunctionRetryableError,
errors.CancelledError,
'The corresponding function is cancelled. Please reschedule the '
'function.'):
closure3._fetch_output_remote_values()

View File

@ -406,7 +406,7 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
with self.assertRaises(client_lib.FunctionRetryableError):
with self.assertRaises(errors.CancelledError):
long_function.fetch()
for _ in range(3):