Introduce async_wait and async_clear_error primitives.

Add tests to demonstrate the usage of the primitives in handling exceptions thrown in remote async execution.

PiperOrigin-RevId: 297041596
Change-Id: Ibc9ffa7c5eaaa9b62c6849e815c0c933ff0ec86c
This commit is contained in:
Haoyu Zhang 2020-02-24 22:03:49 -08:00 committed by TensorFlower Gardener
parent 7fbfccbd30
commit 957792181b
9 changed files with 201 additions and 41 deletions

View File

@ -874,12 +874,12 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#endif // !IS_MOBILE_PLATFORM
}
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
status->status = ctx->context->ClearRemoteExecutors();
status->status = ctx->context->SyncExecutors();
#endif // !IS_MOBILE_PLATFORM
}

View File

@ -382,8 +382,10 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
const char* worker_name,
TF_Status* status);
// Clear pending streaming requests and error statuses on remote executors.
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
// Sync pending nodes in local executors (including the context default executor
// and thread executors) and streaming requests to remote executors, and get the
// combined status.
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
TF_Status* status);
// If the TensorHandle is copied to another device as part of an op execution,

View File

@ -656,34 +656,51 @@ Status EagerContext::RemoveFunction(const string& func) {
return Status::OK();
}
Status EagerContext::ClearRemoteExecutors() {
Status EagerContext::SyncExecutors() {
StatusGroup sg;
// Synchronize on context default executor
sg.Update(default_executor_.WaitForAllPendingNodes());
default_executor_.ClearError();
// Synchronize thread local executors on client
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
{
mutex_lock l(executor_map_mu_);
executors_copy = thread_local_executor_;
}
for (const auto& entry : executors_copy) {
sg.Update(entry.second->WaitForAllPendingNodes());
entry.second->ClearError();
}
#if !defined(IS_MOBILE_PLATFORM)
// Synchronize executors on remote workers
eager::EnqueueRequest request;
request.set_context_id(GetContextId());
request.add_queue()->mutable_clear_remote_executor_for_stream();
request.add_queue()->mutable_sync_remote_executor_for_stream();
BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
std::vector<Status> statuses(remote_contexts_.size());
for (const auto& target : remote_contexts_) {
for (int i = 0; i < remote_contexts_.size(); i++) {
const auto& target = remote_contexts_[i];
core::RefCountPtr<eager::EagerClient> eager_client;
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
eager::EnqueueResponse* response = new eager::EnqueueResponse();
eager_client->StreamingEnqueueAsync(
&request, response, [response, target, &counter](const Status& status) {
if (!status.ok()) {
LOG(ERROR) << "Cleared remote executor on " << target
<< " with status: " << status.error_message();
}
&request, response,
[response, target, &counter, &s = statuses[i]](const Status& status) {
s = status;
delete response;
counter.DecrementCount();
});
}
// Currently we have to block since it appears that ops sent before the clear
// message returns can be cancelled unexpectedly.
// TODO(haoyuzhang): Remove the block.
counter.Wait();
for (const Status& s : statuses) {
sg.Update(s);
}
#endif // !IS_MOBILE_PLATFORM
return Status::OK();
return sg.as_summary_status();
}
core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(

View File

@ -236,8 +236,12 @@ class EagerContext : public core::RefCounted {
Status RemoveFunction(const string& func);
// Clear remote executors on all worker targets in `remote_contexts_`.
Status ClearRemoteExecutors();
// Wait for pending nodes to be finished in local executors (including context
// default executor and thread executors) and executors on remote workers.
// Return combined status of local and remote executors. If there are multiple
// errors, the Status code will be the same as the first executor that has
// errors, and the error message will be combined from all executors.
Status SyncExecutors();
core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);

View File

@ -433,7 +433,7 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
} else if (item.has_cleanup_function()) {
s = CleanupFunction(item.cleanup_function());
} else {
DCHECK(item.has_clear_remote_executor_for_stream());
DCHECK(item.has_sync_remote_executor_for_stream());
s = executor.WaitForAllPendingNodes();
}

View File

@ -56,9 +56,9 @@ message QueueItem {
RegisterFunctionOp register_function = 4;
CleanupFunctionOp cleanup_function = 5;
// A remote executor is created to execute ops/functions asynchronously
// enqueued in streaming call. Request with this item type clears pending
// nodes and status of the executor on the remote worker.
ClearRemoteExecutorForStream clear_remote_executor_for_stream = 6;
// enqueued in streaming call. Request with this item type waits for pending
// nodes to finish on the remote executor and reports the status.
SyncRemoteExecutorForStream sync_remote_executor_for_stream = 6;
}
}
@ -196,7 +196,7 @@ message CleanupFunctionOp {
int64 step_id = 1;
}
message ClearRemoteExecutorForStream {}
message SyncRemoteExecutorForStream {}
message SendTensorOp {
// All remote tensors are identified by <Op ID, Output num>. To mimic this

View File

@ -621,8 +621,25 @@ class Context(object):
else:
raise ValueError("Context is not initialized.")
def clear_remote_executors(self):
"""Clear executors on remote workers.
def sync_executors(self):
"""Sync both local executors and the ones on remote workers.
In async execution mode, local function calls can return before the
corresponding remote op/function execution requests are completed. Calling
this method creates a synchronization barrier for remote executors. It only
returns when all remote pending nodes are finished, potentially with errors
if any remote executors are in error state.
Raises:
ValueError: if context is not initialized.
"""
if self._context_handle:
pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle)
else:
raise ValueError("Context is not initialized.")
def clear_executor_errors(self):
"""Clear errors in both local executors and remote workers.
After receiving errors from remote workers, additional requests on the fly
could further taint the status on the remote workers due to the async nature
@ -633,7 +650,7 @@ class Context(object):
ValueError: if context is not initialized.
"""
if self._context_handle:
pywrap_tfe.TFE_ContextClearRemoteExecutors(self._context_handle)
pywrap_tfe.TFE_ContextClearExecutors(self._context_handle)
else:
raise ValueError("Context is not initialized.")
@ -2019,16 +2036,6 @@ def is_async():
return context().is_async()
def async_wait():
"""Waits for ops dispatched in ASYNC mode to finish."""
return context().executor.wait()
def async_clear_error():
"""Clears errors raised during ASYNC execution mode."""
return context().executor.clear_error()
def num_gpus():
"""Get the number of available GPU devices.
@ -2135,6 +2142,65 @@ def check_alive(worker_name):
return context().check_alive(worker_name)
def async_wait():
"""Sync all async operations and raise any errors during execution.
In async execution mode, an op/function call can return before finishing the
actual execution. Calling this method creates a synchronization barrier for
all async op and function execution. It only returns when all pending nodes
are finished, potentially raising exceptions if async execution results in
an error state.
Users may write the following code to asynchronously invoke `train_step_fn`
and log the `loss` metric for every `num_steps` steps in a training loop.
`train_step_fn` internally consumes data using `iterator.get_next()`, and may
throw OutOfRangeError when running out of data. In the case:
- If the exception is thrown during the loop of scheduling function steps,
the next call to function triggers an exception. In the except block,
we clear the error and break from the loop;
- If all `train_step_fn`s are scheduled before throwing an exception, we
block at the last iteration to wait for the scheduled functions to finish
execution and throw the OutOfRangeError.
```
for i in range(num_steps):
try:
# Step function updates the metric `loss` internally
train_step_fn()
if i == num_steps - 1:
context.async_wait()
except tf.errors.OutOfRangeError:
context.async_clear_error()
break
logging.info('loss = %s', loss.numpy())
```
"""
context().sync_executors()
def async_clear_error():
"""Clear pending operations and error statuses in async execution.
In async execution mode, an error in op/function execution can lead to errors
in subsequent ops/functions that are scheduled but not yet executed. Calling
this method clears all pending operations and resets the async execution state.
Example:
```
while True:
try:
# Step function updates the metric `loss` internally
train_step_fn()
except tf.errors.OutOfRangeError:
context.async_clear_error()
break
logging.info('loss = %s', loss.numpy())
```
"""
context().clear_executor_errors()
def add_function(fdef):
"""Add a function definition to the context."""
context().add_function(fdef)

View File

@ -24,6 +24,7 @@ from absl.testing import parameterized
import numpy as np
import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -155,6 +156,70 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
self.assertIn('Dimensions must be equal', cm.exception.args[0])
class RemoteAsyncTest(test.TestCase):
def setUp(self):
super(RemoteAsyncTest, self).setUp()
workers, _ = test_util.create_local_cluster(1, 0)
remote.connect_to_remote_host(workers[0].target)
def tearDown(self):
super(RemoteAsyncTest, self).tearDown()
# Reset the context to avoid polluting other test cases.
context._reset_context()
def test_out_of_range_with_while_loop(self):
with ops.device('/job:worker/task:0'):
dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
dataset = dataset.batch(1, drop_remainder=False)
iterator = iter(dataset)
v = variables.Variable(1.0)
@def_function.function
def train_step(iterator):
i = next(iterator)
v.assign_add(math_ops.reduce_mean(i))
while True:
try:
with ops.device('/job:worker/task:0'):
train_step(iterator)
except (errors.OutOfRangeError, errors.InternalError):
context.async_clear_error()
break
self.assertAllEqual(v.numpy(), 4.0)
def test_out_of_range_with_for_loop(self):
with ops.device('/job:worker/task:0'):
dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
dataset = dataset.batch(1, drop_remainder=False)
iterator = iter(dataset)
v = variables.Variable(1.0)
@def_function.function
def train_step(iterator):
i = next(iterator)
v.assign_add(math_ops.reduce_mean(i))
num_steps = 3
for i in range(num_steps):
try:
with ops.device('/job:worker/task:0'):
train_step(iterator)
if i == num_steps - 1:
context.async_wait()
except errors.OutOfRangeError:
context.async_clear_error()
break
self.assertAllEqual(v.numpy(), 4.0)
class MultiWorkersTest(test.TestCase, parameterized.TestCase):
def setUp(self):

View File

@ -473,13 +473,19 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("TFE_ContextClearRemoteExecutors", [](py::handle& ctx) {
m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ContextClearRemoteExecutors(tensorflow::InputTFE_Context(ctx),
status.get());
TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
// NOTE: different from TFE_ContextSyncExecutors that raises potential
// errors, deliberately ignore executor statuses in cleanup.
});
// TFE_Executor logic
m.def(