diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index c3cbdb1ade3..bba9d7758ef 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -883,6 +883,15 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
 #endif  // !IS_MOBILE_PLATFORM
 }
 
+TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
+                                                           TF_Status* status) {
+#if defined(IS_MOBILE_PLATFORM)
+  status->status = tensorflow::Status::OK();
+#else  // !defined(IS_MOBILE_PLATFORM)
+  status->status = ctx->context->ClearRemoteExecutors();
+#endif  // !IS_MOBILE_PLATFORM
+}
+
 void TFE_ContextSetThreadLocalDevicePlacementPolicy(
     TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
   ctx->context->SetThreadLocalDevicePlacementPolicy(
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index fe95c238e52..cbd9669a363 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -434,6 +434,10 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
                                                  const char* worker_name,
                                                  TF_Status* status);
 
+// Clear pending streaming requests and error statuses on remote executors.
+TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
+                                                           TF_Status* status);
+
 // This function will block till the operation that produces `h` has
 // completed. This is only valid on local TFE_TensorHandles. The pointer
 // returned will be on the device in which the TFE_TensorHandle resides (so e.g.
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 4e93b0efaab..5932ed4b698 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -647,6 +647,36 @@ Status EagerContext::RemoveFunction(const string& func) {
   return Status::OK();
 }
 
+Status EagerContext::ClearRemoteExecutors() {
+#if !defined(IS_MOBILE_PLATFORM)
+  eager::EnqueueRequest request;
+  request.set_context_id(GetContextId());
+  request.add_queue()->mutable_clear_remote_executor_for_stream();
+  BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
+
+  for (const auto& target : remote_contexts_) {
+    core::RefCountPtr<eager::EagerClient> eager_client;
+    TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
+
+    eager::EnqueueResponse* response = new eager::EnqueueResponse();
+    eager_client->StreamingEnqueueAsync(
+        &request, response, [response, target, &counter](const Status& status) {
+          if (!status.ok()) {
+            LOG(ERROR) << "Cleared remote executor on " << target
+                       << " with status: " << status.error_message();
+          }
+          delete response;
+          counter.DecrementCount();
+        });
+  }
+  // Currently we have to block since it appears that ops sent before the clear
+  // message returns can be cancelled unexpectedly.
+  // TODO(haoyuzhang): Remove the block.
+  counter.Wait();
+#endif  // !IS_MOBILE_PLATFORM
+  return Status::OK();
+}
+
 core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
     Fprint128 cache_key) {
   tf_shared_lock l(cache_mu_);
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 7c964b61e3d..094e7fd8b49 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -234,6 +234,9 @@ class EagerContext : public core::RefCounted {
 
   Status RemoveFunction(const string& func);
 
+  // Clear remote executors on all worker targets in `remote_contexts_`.
+  Status ClearRemoteExecutors();
+
   core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
 
   void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index d57aeb77b22..f85e20db084 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -430,8 +430,11 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
       s = SendTensor(item.send_tensor(), context->Context());
     } else if (item.has_register_function()) {
       s = RegisterFunction(item.register_function(), context->Context());
-    } else {
+    } else if (item.has_cleanup_function()) {
       s = CleanupFunction(item.cleanup_function());
+    } else {
+      DCHECK(item.has_clear_remote_executor_for_stream());
+      s = executor.WaitForAllPendingNodes();
     }
 
     if (!s.ok()) {
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index 6a9ad30f1f8..cd2a7c7a24f 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -55,6 +55,10 @@ message QueueItem {
     // Takes a FunctionDef and makes it enqueable on the remote worker.
    RegisterFunctionOp register_function = 4;
     CleanupFunctionOp cleanup_function = 5;
+    // A remote executor is created to execute ops/functions enqueued
+    // asynchronously in a streaming call. A request with this item type
+    // clears the executor's pending nodes and error status on the worker.
+    ClearRemoteExecutorForStream clear_remote_executor_for_stream = 6;
   }
 }
 
@@ -192,6 +196,8 @@ message CleanupFunctionOp {
   int64 step_id = 1;
 }
 
+message ClearRemoteExecutorForStream {}
+
 message SendTensorOp {
   // All remote tensors are identified by <Op ID, Output num>. To mimic this
   // situation when directly sending tensors, we include an "artificial" op ID
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index b0717936aba..e32e71152f0 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -620,6 +620,22 @@ class Context(object):
     else:
       raise ValueError("Context is not initialized.")
 
+  def clear_remote_executors(self):
+    """Clear executors on remote workers.
+
+    After receiving errors from remote workers, additional in-flight requests
+    could further taint the status on the remote workers due to the async
+    nature of remote execution. Calling this method blocks until all pending
+    nodes in the remote executors finish, and clears their error statuses.
+
+    Raises:
+      ValueError: if context is not initialized.
+    """
+    if self._context_handle:
+      pywrap_tfe.TFE_ContextClearRemoteExecutors(self._context_handle)
+    else:
+      raise ValueError("Context is not initialized.")
+
   def enable_collective_ops(self, server_def):
     """Enable distributed collective ops with an appropriate server_def.
 
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 09532732b5a..f9a233cca7a 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -467,6 +467,13 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
     return output;
   });
+  m.def("TFE_ContextClearRemoteExecutors", [](py::handle& ctx) {
+    tensorflow::Safe_TF_StatusPtr status =
+        tensorflow::make_safe(TF_NewStatus());
+    TFE_ContextClearRemoteExecutors(tensorflow::InputTFE_Context(ctx),
+                                    status.get());
+    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+  });
 
   // TFE_Executor logic
   m.def(
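
Note: the sketch below is a minimal illustration of how client code might call the new Python entry point once this change is in; the cluster setup, the "worker" job name, and the retry wrapper are illustrative assumptions and are not part of this patch.

import tensorflow as tf
from tensorflow.python.eager import context

# Assumes a remote cluster with a job named "worker" has already been
# connected, e.g. via tf.config.experimental_connect_to_cluster(...).

@tf.function
def remote_step(x):
  return x * 2.0

def run_step_with_cleanup(x):
  try:
    with tf.device("/job:worker/replica:0/task:0"):
      return remote_step(tf.constant(x))
  except tf.errors.OpError:
    # After a remote failure, clear pending streaming requests and error
    # statuses on the remote executors before issuing more work, so that
    # in-flight ops from the failed step cannot taint later requests.
    context.context().clear_remote_executors()
    raise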