Adding barrier message to clear remote executors in order to support catching OutOfRangeErrors.
PiperOrigin-RevId: 293716720 Change-Id: I0768c99baf080f817e0985e188ffe330b3e15dcc
This commit is contained in:
parent
265e1be025
commit
bdba822d97
@ -883,6 +883,15 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
||||||
|
TF_Status* status) {
|
||||||
|
#if defined(IS_MOBILE_PLATFORM)
|
||||||
|
status->status = tensorflow::Status::OK();
|
||||||
|
#else // !defined(IS_MOBILE_PLATFORM)
|
||||||
|
status->status = ctx->context->ClearRemoteExecutors();
|
||||||
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
|
}
|
||||||
|
|
||||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||||
ctx->context->SetThreadLocalDevicePlacementPolicy(
|
ctx->context->SetThreadLocalDevicePlacementPolicy(
|
||||||
|
@ -434,6 +434,10 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
const char* worker_name,
|
const char* worker_name,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
|
// Clear pending streaming requests and error statuses on remote executors.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
// This function will block till the operation that produces `h` has
|
// This function will block till the operation that produces `h` has
|
||||||
// completed. This is only valid on local TFE_TensorHandles. The pointer
|
// completed. This is only valid on local TFE_TensorHandles. The pointer
|
||||||
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
|
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
|
||||||
|
@ -647,6 +647,36 @@ Status EagerContext::RemoveFunction(const string& func) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status EagerContext::ClearRemoteExecutors() {
|
||||||
|
#if !defined(IS_MOBILE_PLATFORM)
|
||||||
|
eager::EnqueueRequest request;
|
||||||
|
request.set_context_id(GetContextId());
|
||||||
|
request.add_queue()->mutable_clear_remote_executor_for_stream();
|
||||||
|
BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
|
||||||
|
|
||||||
|
for (const auto& target : remote_contexts_) {
|
||||||
|
core::RefCountPtr<eager::EagerClient> eager_client;
|
||||||
|
TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
|
||||||
|
|
||||||
|
eager::EnqueueResponse* response = new eager::EnqueueResponse();
|
||||||
|
eager_client->StreamingEnqueueAsync(
|
||||||
|
&request, response, [response, target, &counter](const Status& status) {
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(ERROR) << "Cleared remote executor on " << target
|
||||||
|
<< " with status: " << status.error_message();
|
||||||
|
}
|
||||||
|
delete response;
|
||||||
|
counter.DecrementCount();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// Currently we have to block since it appears that ops sent before the clear
|
||||||
|
// message returns can be cancelled unexpectedly.
|
||||||
|
// TODO(haoyuzhang): Remove the block.
|
||||||
|
counter.Wait();
|
||||||
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
|
core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
|
||||||
Fprint128 cache_key) {
|
Fprint128 cache_key) {
|
||||||
tf_shared_lock l(cache_mu_);
|
tf_shared_lock l(cache_mu_);
|
||||||
|
@ -234,6 +234,9 @@ class EagerContext : public core::RefCounted {
|
|||||||
|
|
||||||
Status RemoveFunction(const string& func);
|
Status RemoveFunction(const string& func);
|
||||||
|
|
||||||
|
// Clear remote executors on all worker targets in `remote_contexts_`.
|
||||||
|
Status ClearRemoteExecutors();
|
||||||
|
|
||||||
core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
|
core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
|
||||||
|
|
||||||
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
|
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
|
||||||
|
@ -430,8 +430,11 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
|
|||||||
s = SendTensor(item.send_tensor(), context->Context());
|
s = SendTensor(item.send_tensor(), context->Context());
|
||||||
} else if (item.has_register_function()) {
|
} else if (item.has_register_function()) {
|
||||||
s = RegisterFunction(item.register_function(), context->Context());
|
s = RegisterFunction(item.register_function(), context->Context());
|
||||||
} else {
|
} else if (item.has_cleanup_function()) {
|
||||||
s = CleanupFunction(item.cleanup_function());
|
s = CleanupFunction(item.cleanup_function());
|
||||||
|
} else {
|
||||||
|
DCHECK(item.has_clear_remote_executor_for_stream());
|
||||||
|
s = executor.WaitForAllPendingNodes();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
|
@ -55,6 +55,10 @@ message QueueItem {
|
|||||||
// Takes a FunctionDef and makes it enqueable on the remote worker.
|
// Takes a FunctionDef and makes it enqueable on the remote worker.
|
||||||
RegisterFunctionOp register_function = 4;
|
RegisterFunctionOp register_function = 4;
|
||||||
CleanupFunctionOp cleanup_function = 5;
|
CleanupFunctionOp cleanup_function = 5;
|
||||||
|
// A remote executor is created to execute ops/functions asynchronously
|
||||||
|
// enqueued in streaming call. Request with this item type clears pending
|
||||||
|
// nodes and status of the executor on the remote worker.
|
||||||
|
ClearRemoteExecutorForStream clear_remote_executor_for_stream = 6;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -192,6 +196,8 @@ message CleanupFunctionOp {
|
|||||||
int64 step_id = 1;
|
int64 step_id = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message ClearRemoteExecutorForStream {}
|
||||||
|
|
||||||
message SendTensorOp {
|
message SendTensorOp {
|
||||||
// All remote tensors are identified by <Op ID, Output num>. To mimic this
|
// All remote tensors are identified by <Op ID, Output num>. To mimic this
|
||||||
// situation when directly sending tensors, we include an "artificial" op ID
|
// situation when directly sending tensors, we include an "artificial" op ID
|
||||||
|
@ -620,6 +620,22 @@ class Context(object):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Context is not initialized.")
|
raise ValueError("Context is not initialized.")
|
||||||
|
|
||||||
|
def clear_remote_executors(self):
|
||||||
|
"""Clear executors on remote workers.
|
||||||
|
|
||||||
|
After receiving errors from remote workers, additional requests on the fly
|
||||||
|
could further taint the status on the remote workers due to the async nature
|
||||||
|
of remote execution. Calling this method block on waiting for all pending
|
||||||
|
nodes in remote executors to finish and clear their error statuses.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if context is not initialized.
|
||||||
|
"""
|
||||||
|
if self._context_handle:
|
||||||
|
pywrap_tfe.TFE_ContextClearRemoteExecutors(self._context_handle)
|
||||||
|
else:
|
||||||
|
raise ValueError("Context is not initialized.")
|
||||||
|
|
||||||
def enable_collective_ops(self, server_def):
|
def enable_collective_ops(self, server_def):
|
||||||
"""Enable distributed collective ops with an appropriate server_def.
|
"""Enable distributed collective ops with an appropriate server_def.
|
||||||
|
|
||||||
|
@ -467,6 +467,13 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
return output;
|
return output;
|
||||||
});
|
});
|
||||||
|
m.def("TFE_ContextClearRemoteExecutors", [](py::handle& ctx) {
|
||||||
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
|
TFE_ContextClearRemoteExecutors(tensorflow::InputTFE_Context(ctx),
|
||||||
|
status.get());
|
||||||
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
|
});
|
||||||
|
|
||||||
// TFE_Executor logic
|
// TFE_Executor logic
|
||||||
m.def(
|
m.def(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user