Add a barrier message that clears remote executors, so that OutOfRangeErrors raised on remote workers can be caught and recovered from.

PiperOrigin-RevId: 293716720
Change-Id: I0768c99baf080f817e0985e188ffe330b3e15dcc
This commit is contained in:
Haoyu Zhang 2020-02-06 17:41:14 -08:00 committed by TensorFlower Gardener
parent 265e1be025
commit bdba822d97
8 changed files with 79 additions and 1 deletions

View File

@ -883,6 +883,15 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#endif // !IS_MOBILE_PLATFORM
}
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
                                                           TF_Status* status) {
#if !defined(IS_MOBILE_PLATFORM)
  // Forward to the eager context, which drains pending streaming requests and
  // resets error statuses on every remote executor.
  status->status = ctx->context->ClearRemoteExecutors();
#else   // IS_MOBILE_PLATFORM: no remote execution exists, so this is a no-op.
  status->status = tensorflow::Status::OK();
#endif  // !IS_MOBILE_PLATFORM
}
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context->SetThreadLocalDevicePlacementPolicy(

View File

@ -434,6 +434,10 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
const char* worker_name,
TF_Status* status);
// Clear pending streaming requests and error statuses on remote executors.
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
TF_Status* status);
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. The pointer
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.

View File

@ -647,6 +647,36 @@ Status EagerContext::RemoveFunction(const string& func) {
return Status::OK();
}
// Broadcasts a ClearRemoteExecutorForStream item to every remote context so
// that each worker waits for its pending streaming-enqueue nodes and resets
// the executor's error status. Blocks until all workers have responded.
// Returns OK (per-worker RPC failures are logged, not propagated), or the
// error from looking up a worker's EagerClient.
Status EagerContext::ClearRemoteExecutors() {
#if !defined(IS_MOBILE_PLATFORM)
  eager::EnqueueRequest request;
  request.set_context_id(GetContextId());
  request.add_queue()->mutable_clear_remote_executor_for_stream();
  BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
  for (const auto& target : remote_contexts_) {
    core::RefCountPtr<eager::EagerClient> eager_client;
    TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
    // The response must stay alive until the async call completes; the
    // callback takes ownership and deletes it.
    eager::EnqueueResponse* response = new eager::EnqueueResponse();
    // `request` and `counter` are captured by pointer/reference; this is safe
    // because counter.Wait() below keeps this frame alive until every
    // callback has run.
    eager_client->StreamingEnqueueAsync(
        &request, response, [response, target, &counter](const Status& status) {
          if (!status.ok()) {
            // BUGFIX: this branch only runs on failure, but the old message
            // read "Cleared remote executor ...", implying success.
            LOG(ERROR) << "Failed to clear remote executor on " << target
                       << " with status: " << status.error_message();
          }
          delete response;
          counter.DecrementCount();
        });
  }
  // Currently we have to block since it appears that ops sent before the clear
  // message returns can be cancelled unexpectedly.
  // TODO(haoyuzhang): Remove the block.
  counter.Wait();
#endif  // !IS_MOBILE_PLATFORM
  return Status::OK();
}
core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
Fprint128 cache_key) {
tf_shared_lock l(cache_mu_);

View File

@ -234,6 +234,9 @@ class EagerContext : public core::RefCounted {
Status RemoveFunction(const string& func);
// Clear remote executors on all worker targets in `remote_contexts_`.
Status ClearRemoteExecutors();
core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);

View File

@ -430,8 +430,11 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
s = SendTensor(item.send_tensor(), context->Context());
} else if (item.has_register_function()) {
s = RegisterFunction(item.register_function(), context->Context());
} else {
} else if (item.has_cleanup_function()) {
s = CleanupFunction(item.cleanup_function());
} else {
DCHECK(item.has_clear_remote_executor_for_stream());
s = executor.WaitForAllPendingNodes();
}
if (!s.ok()) {

View File

@ -55,6 +55,10 @@ message QueueItem {
// Takes a FunctionDef and makes it enqueable on the remote worker.
RegisterFunctionOp register_function = 4;
CleanupFunctionOp cleanup_function = 5;
// A remote executor is created to execute ops/functions asynchronously
// enqueued in streaming call. Request with this item type clears pending
// nodes and status of the executor on the remote worker.
ClearRemoteExecutorForStream clear_remote_executor_for_stream = 6;
}
}
@ -192,6 +196,8 @@ message CleanupFunctionOp {
int64 step_id = 1;
}
// Payload-less marker item. When a worker receives it in an Enqueue request,
// it waits for all pending nodes in the streaming-call executor and clears
// the executor's error status (see EagerServiceImpl::Enqueue).
message ClearRemoteExecutorForStream {}
message SendTensorOp {
// All remote tensors are identified by <Op ID, Output num>. To mimic this
// situation when directly sending tensors, we include an "artificial" op ID

View File

@ -620,6 +620,22 @@ class Context(object):
else:
raise ValueError("Context is not initialized.")
def clear_remote_executors(self):
  """Clear executors on remote workers.

  After receiving errors from remote workers, additional requests on the fly
  could further taint the status on the remote workers due to the async nature
  of remote execution. Calling this method blocks until all pending nodes in
  the remote executors have finished, and clears their error statuses.

  Raises:
    ValueError: if context is not initialized.
  """
  if self._context_handle:
    pywrap_tfe.TFE_ContextClearRemoteExecutors(self._context_handle)
  else:
    raise ValueError("Context is not initialized.")
def enable_collective_ops(self, server_def):
"""Enable distributed collective ops with an appropriate server_def.

View File

@ -467,6 +467,13 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
// Python binding: clear remote executors for the given eager context. The
// TF_Status is RAII-managed; a non-OK status is raised as the registered
// Python exception by MaybeRaiseRegisteredFromTFStatus.
m.def("TFE_ContextClearRemoteExecutors", [](py::handle& ctx) {
  tensorflow::Safe_TF_StatusPtr status =
      tensorflow::make_safe(TF_NewStatus());
  TFE_ContextClearRemoteExecutors(tensorflow::InputTFE_Context(ctx),
                                  status.get());
  tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
// TFE_Executor logic
m.def(