Introduce async_wait and async_clear_error primitives.
Add tests to demonstrate the usage of the primitives in handling exceptions thrown in remote async execution.

PiperOrigin-RevId: 297041596
Change-Id: Ibc9ffa7c5eaaa9b62c6849e815c0c933ff0ec86c
commit 957792181b
parent 7fbfccbd30
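For orientation before the diff, here is a minimal sketch of how the two new primitives are meant to be used together. The dataset, variable, and `train_step` below are illustrative placeholders, not part of this change; the pattern mirrors the tests added in this commit, and runs whether or not async mode is enabled (with synchronous execution, `async_wait` is a trivial barrier).

```python
import tensorflow as tf
from tensorflow.python.eager import context

dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0]).batch(1)
iterator = iter(dataset)
v = tf.Variable(1.0)

@tf.function
def train_step(it):
  # Consumes data; raises OutOfRangeError once the iterator is exhausted.
  v.assign_add(tf.reduce_mean(next(it)))

while True:
  try:
    train_step(iterator)
    context.async_wait()  # barrier: surface any deferred execution errors here
  except tf.errors.OutOfRangeError:
    context.async_clear_error()  # reset executor state after the error
    break

print(v.numpy())  # 4.0: both batches were applied before exhaustion
```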
@@ -874,12 +874,12 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
 #endif  // !IS_MOBILE_PLATFORM
 }
 
-TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
-                                                           TF_Status* status) {
+TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
+                                                TF_Status* status) {
 #if defined(IS_MOBILE_PLATFORM)
   status->status = tensorflow::Status::OK();
 #else   // !defined(IS_MOBILE_PLATFORM)
-  status->status = ctx->context->ClearRemoteExecutors();
+  status->status = ctx->context->SyncExecutors();
 #endif  // !IS_MOBILE_PLATFORM
 }
@@ -382,9 +382,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
                                                  const char* worker_name,
                                                  TF_Status* status);
 
-// Clear pending streaming requests and error statuses on remote executors.
-TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
-                                                           TF_Status* status);
+// Sync pending nodes in local executors (including the context default executor
+// and thread executors) and streaming requests to remote executors, and get the
+// combined status.
+TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
+                                                TF_Status* status);
 
 // If the TensorHandle is copied to another device as part of an op execution,
 // the copy is destroyed after the op has executed. Enabling implicit mirroring
@@ -656,34 +656,51 @@ Status EagerContext::RemoveFunction(const string& func) {
   return Status::OK();
 }
 
-Status EagerContext::ClearRemoteExecutors() {
+Status EagerContext::SyncExecutors() {
+  StatusGroup sg;
+  // Synchronize on context default executor
+  sg.Update(default_executor_.WaitForAllPendingNodes());
+  default_executor_.ClearError();
+
+  // Synchronize thread local executors on client
+  std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
+  {
+    mutex_lock l(executor_map_mu_);
+    executors_copy = thread_local_executor_;
+  }
+  for (const auto& entry : executors_copy) {
+    sg.Update(entry.second->WaitForAllPendingNodes());
+    entry.second->ClearError();
+  }
+
 #if !defined(IS_MOBILE_PLATFORM)
+  // Synchronize executors on remote workers
   eager::EnqueueRequest request;
   request.set_context_id(GetContextId());
-  request.add_queue()->mutable_clear_remote_executor_for_stream();
+  request.add_queue()->mutable_sync_remote_executor_for_stream();
   BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
+  std::vector<Status> statuses(remote_contexts_.size());
 
-  for (const auto& target : remote_contexts_) {
+  for (int i = 0; i < remote_contexts_.size(); i++) {
+    const auto& target = remote_contexts_[i];
     core::RefCountPtr<eager::EagerClient> eager_client;
     TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
 
     eager::EnqueueResponse* response = new eager::EnqueueResponse();
     eager_client->StreamingEnqueueAsync(
-        &request, response, [response, target, &counter](const Status& status) {
-          if (!status.ok()) {
-            LOG(ERROR) << "Cleared remote executor on " << target
-                       << " with status: " << status.error_message();
-          }
+        &request, response,
+        [response, target, &counter, &s = statuses[i]](const Status& status) {
+          s = status;
           delete response;
           counter.DecrementCount();
         });
   }
   // Currently we have to block since it appears that ops sent before the clear
   // message returns can be cancelled unexpectedly.
   // TODO(haoyuzhang): Remove the block.
   counter.Wait();
+  for (const Status& s : statuses) {
+    sg.Update(s);
+  }
 #endif  // !IS_MOBILE_PLATFORM
-  return Status::OK();
+  return sg.as_summary_status();
 }
 
 core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
@@ -236,8 +236,12 @@ class EagerContext : public core::RefCounted {
 
   Status RemoveFunction(const string& func);
 
-  // Clear remote executors on all worker targets in `remote_contexts_`.
-  Status ClearRemoteExecutors();
+  // Wait for pending nodes to be finished in local executors (including context
+  // default executor and thread executors) and executors on remote workers.
+  // Return combined status of remote executors. If there are multiple errors,
+  // the Status code will be the same as the first remote executor that has
+  // errors, and the error message will be combined from all executors.
+  Status SyncExecutors();
 
   core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
 
@@ -433,7 +433,7 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
   } else if (item.has_cleanup_function()) {
     s = CleanupFunction(item.cleanup_function());
   } else {
-    DCHECK(item.has_clear_remote_executor_for_stream());
+    DCHECK(item.has_sync_remote_executor_for_stream());
     s = executor.WaitForAllPendingNodes();
   }
 
@@ -56,9 +56,9 @@ message QueueItem {
     RegisterFunctionOp register_function = 4;
     CleanupFunctionOp cleanup_function = 5;
     // A remote executor is created to execute ops/functions asynchronously
-    // enqueued in streaming call. Request with this item type clears pending
-    // nodes and status of the executor on the remote worker.
-    ClearRemoteExecutorForStream clear_remote_executor_for_stream = 6;
+    // enqueued in streaming call. Request with this item type waits for pending
+    // nodes to finish on the remote executor and report status.
+    SyncRemoteExecutorForStream sync_remote_executor_for_stream = 6;
   }
 }
@@ -196,7 +196,7 @@ message CleanupFunctionOp {
   int64 step_id = 1;
 }
 
-message ClearRemoteExecutorForStream {}
+message SyncRemoteExecutorForStream {}
 
 message SendTensorOp {
   // All remote tensors are identified by <Op ID, Output num>. To mimic this
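For illustration, this is roughly the client-side request that `EagerContext::SyncExecutors` assembles (see `request.add_queue()->mutable_sync_remote_executor_for_stream()` above), built here with the Python proto bindings. The `eager_service_pb2` module path and the placeholder context id are assumptions for the sketch, not part of this change.

```python
from tensorflow.core.protobuf import eager_service_pb2

request = eager_service_pb2.EnqueueRequest()
request.context_id = 12345  # placeholder; a real id comes from context creation
# Adding a QueueItem with the sync variant set marks the field as present,
# which is what EagerServiceImpl::Enqueue dispatches on.
request.queue.add().sync_remote_executor_for_stream.SetInParent()
```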
@@ -621,8 +621,25 @@ class Context(object):
     else:
       raise ValueError("Context is not initialized.")
 
-  def clear_remote_executors(self):
-    """Clear executors on remote workers.
+  def sync_executors(self):
+    """Sync both local executors and the ones on remote workers.
+
+    In async execution mode, local function calls can return before the
+    corresponding remote op/function execution requests are completed. Calling
+    this method creates a synchronization barrier for remote executors. It only
+    returns when all remote pending nodes are finished, potentially with errors
+    if any remote executors are in error state.
+
+    Raises:
+      ValueError: if context is not initialized.
+    """
+    if self._context_handle:
+      pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle)
+    else:
+      raise ValueError("Context is not initialized.")
+
+  def clear_executor_errors(self):
+    """Clear errors in both local executors and remote workers.
 
     After receiving errors from remote workers, additional requests on the fly
     could further taint the status on the remote workers due to the async nature
@@ -633,7 +650,7 @@ class Context(object):
       ValueError: if context is not initialized.
     """
     if self._context_handle:
-      pywrap_tfe.TFE_ContextClearRemoteExecutors(self._context_handle)
+      pywrap_tfe.TFE_ContextClearExecutors(self._context_handle)
     else:
       raise ValueError("Context is not initialized.")
 
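A minimal sketch of calling the two new `Context` methods directly, assuming an initialized eager context; in practice they are reached through the module-level `async_wait`/`async_clear_error` helpers added further down.

```python
from tensorflow.python.eager import context

ctx = context.context()
ctx.ensure_initialized()     # avoid the ValueError raised on an uninitialized context
ctx.sync_executors()         # barrier over local + remote executors; raises on error
ctx.clear_executor_errors()  # wait, but deliberately swallow executor errors
```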
@@ -2019,16 +2036,6 @@ def is_async():
   return context().is_async()
 
 
-def async_wait():
-  """Waits for ops dispatched in ASYNC mode to finish."""
-  return context().executor.wait()
-
-
-def async_clear_error():
-  """Clears errors raised during ASYNC execution mode."""
-  return context().executor.clear_error()
-
-
 def num_gpus():
   """Get the number of available GPU devices.
 
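The removed helpers above only touched the calling thread's executor; the replacements (re-added and reimplemented below) synchronize every local executor plus the remote ones. A quick contrast, as a sketch against an initialized context:

```python
from tensorflow.python.eager import context

ctx = context.context()
ctx.executor.wait()         # old async_wait: this thread's executor only
ctx.executor.clear_error()  # old async_clear_error: likewise thread-local
ctx.sync_executors()        # new async_wait: all local and remote executors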
@@ -2135,6 +2142,65 @@ def check_alive(worker_name):
   return context().check_alive(worker_name)
 
 
+def async_wait():
+  """Sync all async operations and raise any errors during execution.
+
+  In async execution mode, an op/function call can return before finishing the
+  actual execution. Calling this method creates a synchronization barrier for
+  all async op and function execution. It only returns when all pending nodes
+  are finished, potentially raising exceptions if async execution results in
+  an error state.
+
+  Users may write the following code to asynchronously invoke `train_step_fn`
+  and log the `loss` metric for every `num_steps` steps in a training loop.
+  `train_step_fn` internally consumes data using `iterator.get_next()`, and may
+  throw OutOfRangeError when running out of data. In this case:
+  - If the exception is thrown during the loop of scheduling function steps,
+    the next call to the function triggers an exception. In the except block,
+    we clear the error and break from the loop;
+  - If all `train_step_fn`s are scheduled before throwing an exception, we
+    block at the last iteration to wait for the scheduled functions to finish
+    execution and throw the OutOfRangeError.
+
+  ```
+  for i in range(num_steps):
+    try:
+      # Step function updates the metric `loss` internally
+      train_step_fn()
+      if i == num_steps - 1:
+        context.async_wait()
+    except tf.errors.OutOfRangeError:
+      context.async_clear_error()
+      break
+    logging.info('loss =', loss.numpy())
+  ```
+  """
+  context().sync_executors()
+
+
+def async_clear_error():
+  """Clear pending operations and error statuses in async execution.
+
+  In async execution mode, an error in op/function execution can lead to errors
+  in subsequent ops/functions that are scheduled but not yet executed. Calling
+  this method clears all pending operations and resets the async execution
+  state.
+
+  Example:
+
+  ```
+  while True:
+    try:
+      # Step function updates the metric `loss` internally
+      train_step_fn()
+    except tf.errors.OutOfRangeError:
+      context.async_clear_error()
+      break
+    logging.info('loss =', loss.numpy())
+  ```
+  """
+  context().clear_executor_errors()
+
+
 def add_function(fdef):
   """Add a function definition to the context."""
   context().add_function(fdef)
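The docstring example above uses placeholder names. A self-contained local variant of the same schedule-then-sync pattern is sketched below; it assumes `context.execution_mode(context.ASYNC)` behaves as in this version's context.py, and where exactly the deferred OutOfRangeError surfaces depends on async scheduling.

```python
import tensorflow as tf
from tensorflow.python.eager import context

dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0]).batch(1)
iterator = iter(dataset)
v = tf.Variable(1.0)

@tf.function
def train_step(it):
  v.assign_add(tf.reduce_mean(next(it)))

num_steps = 3  # one more step than the dataset can serve
with context.execution_mode(context.ASYNC):
  for i in range(num_steps):
    try:
      train_step(iterator)      # may return before execution completes
      if i == num_steps - 1:
        context.async_wait()    # barrier after all steps are scheduled
    except tf.errors.OutOfRangeError:
      context.async_clear_error()
      break

print(v.numpy())  # 4.0 once both real batches have been applied
```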
@@ -24,6 +24,7 @@ from absl.testing import parameterized
 import numpy as np
 import six
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -155,6 +156,70 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
     self.assertIn('Dimensions must be equal', cm.exception.args[0])
 
 
+class RemoteAsyncTest(test.TestCase):
+
+  def setUp(self):
+    super(RemoteAsyncTest, self).setUp()
+
+    workers, _ = test_util.create_local_cluster(1, 0)
+    remote.connect_to_remote_host(workers[0].target)
+
+  def tearDown(self):
+    super(RemoteAsyncTest, self).tearDown()
+
+    # Reset the context to avoid polluting other test cases.
+    context._reset_context()
+
+  def test_out_of_range_with_while_loop(self):
+
+    with ops.device('/job:worker/task:0'):
+      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
+      dataset = dataset.batch(1, drop_remainder=False)
+      iterator = iter(dataset)
+      v = variables.Variable(1.0)
+
+    @def_function.function
+    def train_step(iterator):
+      i = next(iterator)
+      v.assign_add(math_ops.reduce_mean(i))
+
+    while True:
+      try:
+        with ops.device('/job:worker/task:0'):
+          train_step(iterator)
+      except (errors.OutOfRangeError, errors.InternalError):
+        context.async_clear_error()
+        break
+
+    self.assertAllEqual(v.numpy(), 4.0)
+
+  def test_out_of_range_with_for_loop(self):
+
+    with ops.device('/job:worker/task:0'):
+      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
+      dataset = dataset.batch(1, drop_remainder=False)
+      iterator = iter(dataset)
+      v = variables.Variable(1.0)
+
+    @def_function.function
+    def train_step(iterator):
+      i = next(iterator)
+      v.assign_add(math_ops.reduce_mean(i))
+
+    num_steps = 3
+    for i in range(num_steps):
+      try:
+        with ops.device('/job:worker/task:0'):
+          train_step(iterator)
+        if i == num_steps - 1:
+          context.async_wait()
+      except errors.OutOfRangeError:
+        context.async_clear_error()
+        break
+
+    self.assertAllEqual(v.numpy(), 4.0)
+
+
 class MultiWorkersTest(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
@@ -473,13 +473,19 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
     return output;
   });
-  m.def("TFE_ContextClearRemoteExecutors", [](py::handle& ctx) {
+  m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
     tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
-    TFE_ContextClearRemoteExecutors(tensorflow::InputTFE_Context(ctx),
-                                    status.get());
+    TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
   });
+  m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
+    tensorflow::Safe_TF_StatusPtr status =
+        tensorflow::make_safe(TF_NewStatus());
+    TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
+    // NOTE: unlike TFE_ContextSyncExecutors, which raises potential errors,
+    // deliberately ignore executor statuses in cleanup.
+  });
 
   // TFE_Executor logic
   m.def(
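Direct use of the bindings registered above, as a sketch; ordinarily they are reached via `Context.sync_executors()` and `Context.clear_executor_errors()`. `_context_handle` is a private attribute, used here only for illustration.

```python
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context

ctx = context.context()
ctx.ensure_initialized()
pywrap_tfe.TFE_ContextSyncExecutors(ctx._context_handle)   # raises on executor errors
pywrap_tfe.TFE_ContextClearExecutors(ctx._context_handle)  # ignores executor statuses
```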