Introduce async_wait and async_clear_error primitives.
Add tests to demonstrate the usage of the primitives in handling exceptions thrown in remote async execution. PiperOrigin-RevId: 297041596 Change-Id: Ibc9ffa7c5eaaa9b62c6849e815c0c933ff0ec86c
This commit is contained in:
parent
7fbfccbd30
commit
957792181b
@ -874,12 +874,12 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
#if defined(IS_MOBILE_PLATFORM)
|
#if defined(IS_MOBILE_PLATFORM)
|
||||||
status->status = tensorflow::Status::OK();
|
status->status = tensorflow::Status::OK();
|
||||||
#else // !defined(IS_MOBILE_PLATFORM)
|
#else // !defined(IS_MOBILE_PLATFORM)
|
||||||
status->status = ctx->context->ClearRemoteExecutors();
|
status->status = ctx->context->SyncExecutors();
|
||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -382,8 +382,10 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
const char* worker_name,
|
const char* worker_name,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
// Clear pending streaming requests and error statuses on remote executors.
|
// Sync pending nodes in local executors (including the context default executor
|
||||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
// and thread executors) and streaming requests to remote executors, and get the
|
||||||
|
// combined status.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
// If the TensorHandle is copied to another device as part of an op execution,
|
// If the TensorHandle is copied to another device as part of an op execution,
|
||||||
|
@ -656,34 +656,51 @@ Status EagerContext::RemoveFunction(const string& func) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status EagerContext::ClearRemoteExecutors() {
|
Status EagerContext::SyncExecutors() {
|
||||||
|
StatusGroup sg;
|
||||||
|
// Synchronize on context default executor
|
||||||
|
sg.Update(default_executor_.WaitForAllPendingNodes());
|
||||||
|
default_executor_.ClearError();
|
||||||
|
|
||||||
|
// Synchronize thread local executors on client
|
||||||
|
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
|
||||||
|
{
|
||||||
|
mutex_lock l(executor_map_mu_);
|
||||||
|
executors_copy = thread_local_executor_;
|
||||||
|
}
|
||||||
|
for (const auto& entry : executors_copy) {
|
||||||
|
sg.Update(entry.second->WaitForAllPendingNodes());
|
||||||
|
entry.second->ClearError();
|
||||||
|
}
|
||||||
|
|
||||||
#if !defined(IS_MOBILE_PLATFORM)
|
#if !defined(IS_MOBILE_PLATFORM)
|
||||||
|
// Synchronize executors on remote workers
|
||||||
eager::EnqueueRequest request;
|
eager::EnqueueRequest request;
|
||||||
request.set_context_id(GetContextId());
|
request.set_context_id(GetContextId());
|
||||||
request.add_queue()->mutable_clear_remote_executor_for_stream();
|
request.add_queue()->mutable_sync_remote_executor_for_stream();
|
||||||
BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
|
BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
|
||||||
|
std::vector<Status> statuses(remote_contexts_.size());
|
||||||
|
|
||||||
for (const auto& target : remote_contexts_) {
|
for (int i = 0; i < remote_contexts_.size(); i++) {
|
||||||
|
const auto& target = remote_contexts_[i];
|
||||||
core::RefCountPtr<eager::EagerClient> eager_client;
|
core::RefCountPtr<eager::EagerClient> eager_client;
|
||||||
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
|
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
|
||||||
|
|
||||||
eager::EnqueueResponse* response = new eager::EnqueueResponse();
|
eager::EnqueueResponse* response = new eager::EnqueueResponse();
|
||||||
eager_client->StreamingEnqueueAsync(
|
eager_client->StreamingEnqueueAsync(
|
||||||
&request, response, [response, target, &counter](const Status& status) {
|
&request, response,
|
||||||
if (!status.ok()) {
|
[response, target, &counter, &s = statuses[i]](const Status& status) {
|
||||||
LOG(ERROR) << "Cleared remote executor on " << target
|
s = status;
|
||||||
<< " with status: " << status.error_message();
|
|
||||||
}
|
|
||||||
delete response;
|
delete response;
|
||||||
counter.DecrementCount();
|
counter.DecrementCount();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
// Currently we have to block since it appears that ops sent before the clear
|
|
||||||
// message returns can be cancelled unexpectedly.
|
|
||||||
// TODO(haoyuzhang): Remove the block.
|
|
||||||
counter.Wait();
|
counter.Wait();
|
||||||
|
for (const Status& s : statuses) {
|
||||||
|
sg.Update(s);
|
||||||
|
}
|
||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
return Status::OK();
|
return sg.as_summary_status();
|
||||||
}
|
}
|
||||||
|
|
||||||
core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
|
core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
|
||||||
|
@ -236,8 +236,12 @@ class EagerContext : public core::RefCounted {
|
|||||||
|
|
||||||
Status RemoveFunction(const string& func);
|
Status RemoveFunction(const string& func);
|
||||||
|
|
||||||
// Clear remote executors on all worker targets in `remote_contexts_`.
|
// Wait for pending nodes to be finished in local executors (including context
|
||||||
Status ClearRemoteExecutors();
|
// default executor and thread executors) and executors on remote workers.
|
||||||
|
// Return combined status of remote executors. If there are multiple errors,
|
||||||
|
// the Status code will be the same as the first remote executor that has
|
||||||
|
// errors, and the error message will be combined from all executors.
|
||||||
|
Status SyncExecutors();
|
||||||
|
|
||||||
core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
|
core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
|
||||||
|
|
||||||
|
@ -433,7 +433,7 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
|
|||||||
} else if (item.has_cleanup_function()) {
|
} else if (item.has_cleanup_function()) {
|
||||||
s = CleanupFunction(item.cleanup_function());
|
s = CleanupFunction(item.cleanup_function());
|
||||||
} else {
|
} else {
|
||||||
DCHECK(item.has_clear_remote_executor_for_stream());
|
DCHECK(item.has_sync_remote_executor_for_stream());
|
||||||
s = executor.WaitForAllPendingNodes();
|
s = executor.WaitForAllPendingNodes();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,9 +56,9 @@ message QueueItem {
|
|||||||
RegisterFunctionOp register_function = 4;
|
RegisterFunctionOp register_function = 4;
|
||||||
CleanupFunctionOp cleanup_function = 5;
|
CleanupFunctionOp cleanup_function = 5;
|
||||||
// A remote executor is created to execute ops/functions asynchronously
|
// A remote executor is created to execute ops/functions asynchronously
|
||||||
// enqueued in streaming call. Request with this item type clears pending
|
// enqueued in streaming call. Request with this item type waits for pending
|
||||||
// nodes and status of the executor on the remote worker.
|
// nodes to finish on the remote executor and report status.
|
||||||
ClearRemoteExecutorForStream clear_remote_executor_for_stream = 6;
|
SyncRemoteExecutorForStream sync_remote_executor_for_stream = 6;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -196,7 +196,7 @@ message CleanupFunctionOp {
|
|||||||
int64 step_id = 1;
|
int64 step_id = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message ClearRemoteExecutorForStream {}
|
message SyncRemoteExecutorForStream {}
|
||||||
|
|
||||||
message SendTensorOp {
|
message SendTensorOp {
|
||||||
// All remote tensors are identified by <Op ID, Output num>. To mimic this
|
// All remote tensors are identified by <Op ID, Output num>. To mimic this
|
||||||
|
@ -621,8 +621,25 @@ class Context(object):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Context is not initialized.")
|
raise ValueError("Context is not initialized.")
|
||||||
|
|
||||||
def clear_remote_executors(self):
|
def sync_executors(self):
|
||||||
"""Clear executors on remote workers.
|
"""Sync both local executors and the ones on remote workers.
|
||||||
|
|
||||||
|
In async execution mode, local function calls can return before the
|
||||||
|
coresponding remote op/function execution requests are completed. Calling
|
||||||
|
this method creates a synchronization barrier for remote executors. It only
|
||||||
|
returns when all remote pending nodes are finished, potentially with errors
|
||||||
|
if any remote executors are in error state.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if context is not initialized.
|
||||||
|
"""
|
||||||
|
if self._context_handle:
|
||||||
|
pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle)
|
||||||
|
else:
|
||||||
|
raise ValueError("Context is not initialized.")
|
||||||
|
|
||||||
|
def clear_executor_errors(self):
|
||||||
|
"""Clear errors in both local executors and remote workers.
|
||||||
|
|
||||||
After receiving errors from remote workers, additional requests on the fly
|
After receiving errors from remote workers, additional requests on the fly
|
||||||
could further taint the status on the remote workers due to the async nature
|
could further taint the status on the remote workers due to the async nature
|
||||||
@ -633,7 +650,7 @@ class Context(object):
|
|||||||
ValueError: if context is not initialized.
|
ValueError: if context is not initialized.
|
||||||
"""
|
"""
|
||||||
if self._context_handle:
|
if self._context_handle:
|
||||||
pywrap_tfe.TFE_ContextClearRemoteExecutors(self._context_handle)
|
pywrap_tfe.TFE_ContextClearExecutors(self._context_handle)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Context is not initialized.")
|
raise ValueError("Context is not initialized.")
|
||||||
|
|
||||||
@ -2019,16 +2036,6 @@ def is_async():
|
|||||||
return context().is_async()
|
return context().is_async()
|
||||||
|
|
||||||
|
|
||||||
def async_wait():
|
|
||||||
"""Waits for ops dispatched in ASYNC mode to finish."""
|
|
||||||
return context().executor.wait()
|
|
||||||
|
|
||||||
|
|
||||||
def async_clear_error():
|
|
||||||
"""Clears errors raised during ASYNC execution mode."""
|
|
||||||
return context().executor.clear_error()
|
|
||||||
|
|
||||||
|
|
||||||
def num_gpus():
|
def num_gpus():
|
||||||
"""Get the number of available GPU devices.
|
"""Get the number of available GPU devices.
|
||||||
|
|
||||||
@ -2135,6 +2142,65 @@ def check_alive(worker_name):
|
|||||||
return context().check_alive(worker_name)
|
return context().check_alive(worker_name)
|
||||||
|
|
||||||
|
|
||||||
|
def async_wait():
|
||||||
|
"""Sync all async operations and raise any errors during execution.
|
||||||
|
|
||||||
|
In async execution mode, an op/function call can return before finishing the
|
||||||
|
actual execution. Calling this method creates a synchronization barrier for
|
||||||
|
all async op and function execution. It only returns when all pending nodes
|
||||||
|
are finished, potentially raising exceptions if async execution results in
|
||||||
|
an error state.
|
||||||
|
|
||||||
|
Users may write the following code to asynchronuously invoke `train_step_fn`
|
||||||
|
and log the `loss` metric for every `num_steps` steps in a training loop.
|
||||||
|
`train_step_fn` internally consumes data using `iterator.get_next()`, and may
|
||||||
|
throw OutOfRangeError when running out of data. In the case:
|
||||||
|
- If the exception is thrown during the loop of scheduling function steps,
|
||||||
|
the next call to function triggers an exception. In the except block,
|
||||||
|
we clear the error and break from the loop;
|
||||||
|
- If all `train_step_fn`s are scheduled before throwing an exception, we
|
||||||
|
block at the last iteration to wait for the scheduled functions to finish
|
||||||
|
excution and throw the OutOfRangeError.
|
||||||
|
|
||||||
|
```
|
||||||
|
for i in range(num_steps):
|
||||||
|
try:
|
||||||
|
# Step function updates the metric `loss` internally
|
||||||
|
train_step_fn()
|
||||||
|
if i == num_steps - 1:
|
||||||
|
context.async_wait()
|
||||||
|
except tf.errors.OutOfRangeError:
|
||||||
|
context.async_clear_error()
|
||||||
|
break
|
||||||
|
logging.info('loss =', loss.numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
context().sync_executors()
|
||||||
|
|
||||||
|
|
||||||
|
def async_clear_error():
|
||||||
|
"""Clear pending operations and error statuses in async execution.
|
||||||
|
|
||||||
|
In async execution mode, an error in op/function execution can lead to errors
|
||||||
|
in subsequent ops/functions that are scheduled but not yet executed. Calling
|
||||||
|
this method clears all pending operations and reset the async execution state.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Step function updates the metric `loss` internally
|
||||||
|
train_step_fn()
|
||||||
|
except tf.errors.OutOfRangeError:
|
||||||
|
context.async_clear_error()
|
||||||
|
break
|
||||||
|
logging.info('loss =', loss.numpy())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
context().clear_executor_errors()
|
||||||
|
|
||||||
|
|
||||||
def add_function(fdef):
|
def add_function(fdef):
|
||||||
"""Add a function definition to the context."""
|
"""Add a function definition to the context."""
|
||||||
context().add_function(fdef)
|
context().add_function(fdef)
|
||||||
|
@ -24,6 +24,7 @@ from absl.testing import parameterized
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
@ -155,6 +156,70 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertIn('Dimensions must be equal', cm.exception.args[0])
|
self.assertIn('Dimensions must be equal', cm.exception.args[0])
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteAsyncTest(test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(RemoteAsyncTest, self).setUp()
|
||||||
|
|
||||||
|
workers, _ = test_util.create_local_cluster(1, 0)
|
||||||
|
remote.connect_to_remote_host(workers[0].target)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
super(RemoteAsyncTest, self).tearDown()
|
||||||
|
|
||||||
|
# Reset the context to avoid polluting other test cases.
|
||||||
|
context._reset_context()
|
||||||
|
|
||||||
|
def test_out_of_range_with_while_loop(self):
|
||||||
|
|
||||||
|
with ops.device('/job:worker/task:0'):
|
||||||
|
dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
|
||||||
|
dataset = dataset.batch(1, drop_remainder=False)
|
||||||
|
iterator = iter(dataset)
|
||||||
|
v = variables.Variable(1.0)
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def train_step(iterator):
|
||||||
|
i = next(iterator)
|
||||||
|
v.assign_add(math_ops.reduce_mean(i))
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
with ops.device('/job:worker/task:0'):
|
||||||
|
train_step(iterator)
|
||||||
|
except (errors.OutOfRangeError, errors.InternalError):
|
||||||
|
context.async_clear_error()
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertAllEqual(v.numpy(), 4.0)
|
||||||
|
|
||||||
|
def test_out_of_range_with_for_loop(self):
|
||||||
|
|
||||||
|
with ops.device('/job:worker/task:0'):
|
||||||
|
dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
|
||||||
|
dataset = dataset.batch(1, drop_remainder=False)
|
||||||
|
iterator = iter(dataset)
|
||||||
|
v = variables.Variable(1.0)
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def train_step(iterator):
|
||||||
|
i = next(iterator)
|
||||||
|
v.assign_add(math_ops.reduce_mean(i))
|
||||||
|
|
||||||
|
num_steps = 3
|
||||||
|
for i in range(num_steps):
|
||||||
|
try:
|
||||||
|
with ops.device('/job:worker/task:0'):
|
||||||
|
train_step(iterator)
|
||||||
|
if i == num_steps - 1:
|
||||||
|
context.async_wait()
|
||||||
|
except errors.OutOfRangeError:
|
||||||
|
context.async_clear_error()
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertAllEqual(v.numpy(), 4.0)
|
||||||
|
|
||||||
|
|
||||||
class MultiWorkersTest(test.TestCase, parameterized.TestCase):
|
class MultiWorkersTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -473,13 +473,19 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
return output;
|
return output;
|
||||||
});
|
});
|
||||||
m.def("TFE_ContextClearRemoteExecutors", [](py::handle& ctx) {
|
m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
|
||||||
tensorflow::Safe_TF_StatusPtr status =
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
tensorflow::make_safe(TF_NewStatus());
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
TFE_ContextClearRemoteExecutors(tensorflow::InputTFE_Context(ctx),
|
TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
|
||||||
status.get());
|
|
||||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
});
|
});
|
||||||
|
m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
|
||||||
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
|
TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
|
||||||
|
// NOTE: different from TFE_ContextSyncExecutors that raises potential
|
||||||
|
// errors, deliberately ignore executor statuses in cleanup.
|
||||||
|
});
|
||||||
|
|
||||||
// TFE_Executor logic
|
// TFE_Executor logic
|
||||||
m.def(
|
m.def(
|
||||||
|
Loading…
Reference in New Issue
Block a user