Adding barrier message to clear remote executors in order to support catching OutOfRangeErrors.
PiperOrigin-RevId: 293716720 Change-Id: I0768c99baf080f817e0985e188ffe330b3e15dcc
This commit is contained in:
parent
265e1be025
commit
bdba822d97
@ -883,6 +883,15 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
status->status = ctx->context->ClearRemoteExecutors();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||
ctx->context->SetThreadLocalDevicePlacementPolicy(
|
||||
|
@ -434,6 +434,10 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
const char* worker_name,
|
||||
TF_Status* status);
|
||||
|
||||
// Clear pending streaming requests and error statuses on remote executors.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// This function will block till the operation that produces `h` has
|
||||
// completed. This is only valid on local TFE_TensorHandles. The pointer
|
||||
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
|
||||
|
@ -647,6 +647,36 @@ Status EagerContext::RemoveFunction(const string& func) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EagerContext::ClearRemoteExecutors() {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
eager::EnqueueRequest request;
|
||||
request.set_context_id(GetContextId());
|
||||
request.add_queue()->mutable_clear_remote_executor_for_stream();
|
||||
BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
|
||||
|
||||
for (const auto& target : remote_contexts_) {
|
||||
core::RefCountPtr<eager::EagerClient> eager_client;
|
||||
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
|
||||
|
||||
eager::EnqueueResponse* response = new eager::EnqueueResponse();
|
||||
eager_client->StreamingEnqueueAsync(
|
||||
&request, response, [response, target, &counter](const Status& status) {
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Cleared remote executor on " << target
|
||||
<< " with status: " << status.error_message();
|
||||
}
|
||||
delete response;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
// Currently we have to block since it appears that ops sent before the clear
|
||||
// message returns can be cancelled unexpectedly.
|
||||
// TODO(haoyuzhang): Remove the block.
|
||||
counter.Wait();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
|
||||
Fprint128 cache_key) {
|
||||
tf_shared_lock l(cache_mu_);
|
||||
|
@ -234,6 +234,9 @@ class EagerContext : public core::RefCounted {
|
||||
|
||||
Status RemoveFunction(const string& func);
|
||||
|
||||
// Clear remote executors on all worker targets in `remote_contexts_`.
|
||||
Status ClearRemoteExecutors();
|
||||
|
||||
core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
|
||||
|
||||
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
|
||||
|
@ -430,8 +430,11 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
|
||||
s = SendTensor(item.send_tensor(), context->Context());
|
||||
} else if (item.has_register_function()) {
|
||||
s = RegisterFunction(item.register_function(), context->Context());
|
||||
} else {
|
||||
} else if (item.has_cleanup_function()) {
|
||||
s = CleanupFunction(item.cleanup_function());
|
||||
} else {
|
||||
DCHECK(item.has_clear_remote_executor_for_stream());
|
||||
s = executor.WaitForAllPendingNodes();
|
||||
}
|
||||
|
||||
if (!s.ok()) {
|
||||
|
@ -55,6 +55,10 @@ message QueueItem {
|
||||
// Takes a FunctionDef and makes it enqueable on the remote worker.
|
||||
RegisterFunctionOp register_function = 4;
|
||||
CleanupFunctionOp cleanup_function = 5;
|
||||
// A remote executor is created to execute ops/functions asynchronously
|
||||
// enqueued in streaming call. Request with this item type clears pending
|
||||
// nodes and status of the executor on the remote worker.
|
||||
ClearRemoteExecutorForStream clear_remote_executor_for_stream = 6;
|
||||
}
|
||||
}
|
||||
|
||||
@ -192,6 +196,8 @@ message CleanupFunctionOp {
|
||||
int64 step_id = 1;
|
||||
}
|
||||
|
||||
message ClearRemoteExecutorForStream {}
|
||||
|
||||
message SendTensorOp {
|
||||
// All remote tensors are identified by <Op ID, Output num>. To mimic this
|
||||
// situation when directly sending tensors, we include an "artificial" op ID
|
||||
|
@ -620,6 +620,22 @@ class Context(object):
|
||||
else:
|
||||
raise ValueError("Context is not initialized.")
|
||||
|
||||
def clear_remote_executors(self):
|
||||
"""Clear executors on remote workers.
|
||||
|
||||
After receiving errors from remote workers, additional requests on the fly
|
||||
could further taint the status on the remote workers due to the async nature
|
||||
of remote execution. Calling this method block on waiting for all pending
|
||||
nodes in remote executors to finish and clear their error statuses.
|
||||
|
||||
Raises:
|
||||
ValueError: if context is not initialized.
|
||||
"""
|
||||
if self._context_handle:
|
||||
pywrap_tfe.TFE_ContextClearRemoteExecutors(self._context_handle)
|
||||
else:
|
||||
raise ValueError("Context is not initialized.")
|
||||
|
||||
def enable_collective_ops(self, server_def):
|
||||
"""Enable distributed collective ops with an appropriate server_def.
|
||||
|
||||
|
@ -467,6 +467,13 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
m.def("TFE_ContextClearRemoteExecutors", [](py::handle& ctx) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
TFE_ContextClearRemoteExecutors(tensorflow::InputTFE_Context(ctx),
|
||||
status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
});
|
||||
|
||||
// TFE_Executor logic
|
||||
m.def(
|
||||
|
Loading…
Reference in New Issue
Block a user