PSv2: Replace FunctionRetryableError
with tf.errors.CancelledError
as they
should deserve the same user treatment. PiperOrigin-RevId: 335677674 Change-Id: I9824aaba7c40fafa9604950165903f9fcb072210
This commit is contained in:
parent
d7bb5785ce
commit
9340214eef
@ -163,13 +163,14 @@ class RemoteValue(object):
|
||||
The remote value, as a numpy data type (if scalar) or ndarray.
|
||||
|
||||
Raises:
|
||||
FunctionRetryableError: If the function that produces this `RemoteValue`
|
||||
tf.errors.CancelledError: If the function that produces this `RemoteValue`
|
||||
is aborted or cancelled due to failure, and the user should handle and
|
||||
reschedule.
|
||||
"""
|
||||
self._status_available_event.wait()
|
||||
if self._status is _RemoteValueStatus.ABORTED:
|
||||
raise FunctionRetryableError(
|
||||
raise errors.CancelledError(
|
||||
None, None,
|
||||
"The corresponding function is aborted. Please reschedule the "
|
||||
"function.")
|
||||
if self._error is not None:
|
||||
@ -191,11 +192,6 @@ class InputError(Exception):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class FunctionRetryableError(Exception):
|
||||
"""An error that represents the closure was aborted and should be retried."""
|
||||
pass
|
||||
|
||||
|
||||
def _maybe_rebuild_remote_values(worker, structure):
|
||||
"""Attempts to return errors from `RemoteValue`s. Rebuilds them if needed."""
|
||||
errors_in_structure = []
|
||||
@ -327,9 +323,9 @@ class Closure(object):
|
||||
def _set_output_remote_values_cancelled(self):
|
||||
nest.map_structure(
|
||||
lambda x: x._set_error( # pylint: disable=protected-access,g-long-lambda
|
||||
FunctionRetryableError("The corresponding function is "
|
||||
"cancelled. Please reschedule the "
|
||||
"function.")),
|
||||
errors.CancelledError(
|
||||
None, None, "The corresponding function is "
|
||||
"cancelled. Please reschedule the function.")),
|
||||
self._output_remote_values) # pylint: disable=protected-access
|
||||
|
||||
def execute_on(self, worker):
|
||||
|
@ -27,6 +27,7 @@ import time
|
||||
from tensorflow.python.distribute.client import client
|
||||
from tensorflow.python.eager import cancellation
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import coordinator
|
||||
@ -203,7 +204,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
self.assertTrue(closure_queue.done())
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
client.FunctionRetryableError,
|
||||
errors.CancelledError,
|
||||
'The corresponding function is cancelled. Please reschedule the '
|
||||
'function.'):
|
||||
closure2._fetch_output_remote_values()
|
||||
@ -225,7 +226,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
self.assertTrue(closure_queue.done())
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
client.FunctionRetryableError,
|
||||
errors.CancelledError,
|
||||
'The corresponding function is cancelled. Please reschedule the '
|
||||
'function.'):
|
||||
closure2._fetch_output_remote_values()
|
||||
@ -307,7 +308,7 @@ class CoordinatedClosureQueueTest(test.TestCase):
|
||||
# The following asserts that closure3 should have been cancelled.
|
||||
if not call_wait:
|
||||
with self.assertRaisesRegex(
|
||||
client.FunctionRetryableError,
|
||||
errors.CancelledError,
|
||||
'The corresponding function is cancelled. Please reschedule the '
|
||||
'function.'):
|
||||
closure3._fetch_output_remote_values()
|
||||
|
@ -406,7 +406,7 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.client.join()
|
||||
|
||||
with self.assertRaises(client_lib.FunctionRetryableError):
|
||||
with self.assertRaises(errors.CancelledError):
|
||||
long_function.fetch()
|
||||
|
||||
for _ in range(3):
|
||||
|
Loading…
x
Reference in New Issue
Block a user