diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
index c1b379f700d..08405edf3d1 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
@@ -71,25 +71,42 @@ RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
                                Device* recv_device, uint64 recv_op_id)
     : EagerNode(),
       src_(src),
-      dst_(dst),
       ctx_(ctx),
       executor_(executor),
       send_device_(src->DeviceOrHostCPU(ctx)),
       recv_device_(recv_device),
       wire_id_(GetUniqueWireID()),
-      recv_op_id_(recv_op_id) {
+      recv_op_id_(recv_op_id),
+      captured_state_(std::make_shared<CapturedSharedState>(dst)) {
   DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
   src_->Ref();
-  dst_->Ref();
   ctx_->Ref();
 }
 
-Status RemoteCopyNode::RunSend() {
+Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
+  TF_RETURN_IF_ERROR(executor_->status());
+
+  op->AddInput(src_);
+
+  core::RefCountPtr<KernelAndDevice> kernel;
+  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
+
+  gtl::InlinedVector<TensorValue, 4> input_vector(1);
+  TF_RETURN_IF_ERROR(src_->TensorValue(&input_vector[0]));
+
+  return kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr, nullptr);
+}
+
+Status RemoteCopyNode::StartSend() {
   // TODO(gjn): We should consider just using the low-level SendOp::Compute()
   // functionality here instead of constructing an Op.
   const AttrTypeMap* types;
   bool is_function = false;
-  TF_RETURN_IF_ERROR(AttrTypeMapForOp("_Send", &types, &is_function));
+  Status status = AttrTypeMapForOp("_Send", &types, &is_function);
+  if (!status.ok()) {
+    captured_state_->SetSendStatus(status);
+    return status;
+  }
   DCHECK(!is_function);
   EagerOperation op(ctx_, "_Send", /*is_function=*/false, types);
 
@@ -108,52 +125,116 @@ Status RemoteCopyNode::RunSend() {
   DCHECK(send_device_ != nullptr);
 
   if (send_device_->IsLocal()) {
-    TF_RETURN_IF_ERROR(executor_->status());
-
-    op.AddInput(src_);
-
-    core::RefCountPtr<KernelAndDevice> kernel;
-    TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(&op, &kernel));
-
-    gtl::InlinedVector<TensorValue, 4> input_vector(1);
-    TF_RETURN_IF_ERROR(src_->TensorValue(&input_vector[0]));
-
-    TF_RETURN_IF_ERROR(
-        kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr, nullptr));
+    status = RunLocalSend(&op);
+    captured_state_->SetSendStatus(status);
+    return status;
   } else {
-    eager::EagerClient* eager_client;
-    uint64 context_id = ctx_->GetContextId();
-    TF_RETURN_IF_ERROR(ctx_->GetClient(send_device_, &eager_client));
-
-    std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
-    request->set_context_id(context_id);
-
-    auto* remote_op = request->add_queue()->mutable_operation();
-    TF_RETURN_IF_ERROR(ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
-        src_, remote_op->add_inputs(), src_->device()));
+    // Prepare the request
+    EnqueueRequest request;
+    request.set_context_id(ctx_->GetContextId());
+    auto* remote_op = request.add_queue()->mutable_operation();
+    status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
+        src_, remote_op->add_inputs(), src_->device());
+    if (!status.ok()) {
+      captured_state_->SetSendStatus(status);
+      return status;
+    }
 
     PrepareRemoteOp(remote_op, &op);
     remote_op->set_id(ctx_->RemoteMgr()->NextOpId());
 
-    auto* response = new EnqueueResponse;
-    eager_client->EnqueueAsync(request.get(), response,
-                               [this, response](const Status& s) {
-                                 send_status_.Update(s);
-                                 if (!s.ok()) {
-                                   recv_cancellation_.StartCancel();
-                                 }
-                                 delete response;
-                               });
+    // Issue the RPC
+    eager::EagerClient* eager_client;
+    status = ctx_->GetClient(send_device_, &eager_client);
+    if (!status.ok()) {
+      captured_state_->SetSendStatus(status);
+      return status;
+    }
+
+    const std::shared_ptr<CapturedSharedState>& captured_state =
+        captured_state_;
+    EnqueueResponse* response = new EnqueueResponse;
+    // If StartRecv fails very quickly, `this` can be destroyed before the
+    // callback below is executed. So, we can't capture `this`.
+    eager_client->EnqueueAsync(
+        &request, response, [response, captured_state](const Status& s) {
+          captured_state->SetSendStatus(s);
+          if (!s.ok()) {
+            captured_state->recv_cancellation()->StartCancel();
+          }
+          delete response;
+        });
+    return Status::OK();
   }
-
-  return Status::OK();
 }
 
-Status RemoteCopyNode::RunRecv() {
+Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
+                                    std::vector<Tensor>* outputs) {
+  TF_RETURN_IF_ERROR(executor_->status());
+
+  core::RefCountPtr<KernelAndDevice> kernel;
+  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
+
+  gtl::InlinedVector<TensorValue, 4> input_vector;
+  return kernel->Run(input_vector, outputs, nullptr, nullptr, nullptr,
+                     captured_state_->recv_cancellation());
+}
+
+Status RemoteCopyNode::RunRemoteRecv(EagerOperation* op) {
+  EnqueueRequest request;
+  uint64 context_id = ctx_->GetContextId();
+  request.set_context_id(context_id);
+  auto* remote_op = request.add_queue()->mutable_operation();
+  PrepareRemoteOp(remote_op, op);
+  remote_op->set_id(recv_op_id_);
+
+  eager::EagerClient* eager_client;
+  Status status = ctx_->GetClient(recv_device_, &eager_client);
+  if (!status.ok()) {
+    captured_state_->dst()->Poison(status);
+    return status;
+  }
+
+  EnqueueResponse* response = new EnqueueResponse;
+
+  // Don't issue the recv until send has completed.
+  //  - local send will complete very quickly.
+  //  - remote send will take some time, but remote->remote copy is
+  //    probably rare enough that we don't care much.
+  // Blocks until send has completed.
+  Status send_status = captured_state_->GetSendStatus();
+
+  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
+  Device* recv_device = recv_device_;
+  Notification n;
+  eager_client->EnqueueAsync(
+      &request, response,
+      [captured_state, response, recv_device, &n, &status](const Status& s) {
+        status.Update(s);
+        if (status.ok()) {
+          status = captured_state->dst()->SetRemoteShape(
+              response->queue_response(0).shape(0), recv_device);
+        } else {
+          captured_state->dst()->Poison(status);
+        }
+        delete response;
+        n.Notify();
+      });
+  n.WaitForNotification();
+
+  return status;
+}
+
+Status RemoteCopyNode::StartRecv() {
   // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
   // functionality here instead of constructing an Op.
   const AttrTypeMap* types;
   bool is_function = false;
-  TF_RETURN_IF_ERROR(AttrTypeMapForOp("_Recv", &types, &is_function));
+  Status status = AttrTypeMapForOp("_Recv", &types, &is_function);
+  if (!status.ok()) {
+    captured_state_->dst()->Poison(status);
+    return status;
+  }
   DCHECK(!is_function);
   EagerOperation op(ctx_, "_Recv", /*is_function=*/false, types);
 
@@ -170,93 +251,46 @@ Status RemoteCopyNode::RunRecv() {
   op.MutableAttrs()->Set("tensor_type", src_->dtype);
 
   if (recv_device_->IsLocal()) {
-    TF_RETURN_IF_ERROR(executor_->status());
-
-    core::RefCountPtr<KernelAndDevice> kernel;
-    TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(&op, &kernel));
-
-    std::vector<Tensor> outputs;
-    gtl::InlinedVector<TensorValue, 4> input_vector;
-    TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, nullptr, nullptr,
-                                   nullptr, &recv_cancellation_));
-    return dst_->SetTensor(outputs[0]);
-  } else {
-    eager::EagerClient* eager_client;
-    uint64 context_id = ctx_->GetContextId();
-    TF_RETURN_IF_ERROR(ctx_->GetClient(recv_device_, &eager_client));
-
-    std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
-
-    request->set_context_id(context_id);
-
-    auto* remote_op = request->add_queue()->mutable_operation();
-    PrepareRemoteOp(remote_op, &op);
-    remote_op->set_id(recv_op_id_);
-
-    EnqueueResponse response;
-    Status status;
-    Notification n;
-
-    CancellationToken token = recv_cancellation_.get_cancellation_token();
-    bool already_cancelled =
-        !recv_cancellation_.RegisterCallback(token, [&n, &status] {
-          status.Update(errors::Cancelled(
-              "Recv op is cancelled due to an error in Send op."));
-          n.Notify();
-        });
-
-    if (already_cancelled) {
-      status =
-          errors::Cancelled("Recv op is cancelled due to an error in Send op.");
-    } else {
-      // Note(fishx): When the recv op is cancelled, we doesn't clean up the
-      // state on remote server. So the recv op may ran successfully on the
-      // remote server even though we cancel it on client.
-      eager_client->EnqueueAsync(request.get(), &response,
-                                 [this, &n, &status](const Status& s) {
-                                   if (recv_cancellation_.IsCancelled()) return;
-                                   status.Update(s);
-                                   n.Notify();
-                                 });
-      n.WaitForNotification();
-      recv_cancellation_.DeregisterCallback(token);
+    std::vector<Tensor> outputs(1);
+    status = RunLocalRecv(&op, &outputs);
+    if (!status.ok()) {
+      captured_state_->dst()->Poison(status);
+      return status;
     }
-
-    TF_RETURN_IF_ERROR(status);
-
-    return dst_->SetRemoteShape(response.queue_response(0).shape(0),
-                                recv_device_);
+    return captured_state_->dst()->SetTensor(outputs[0]);
+  } else {
+    // Handles captured_state_->dst_ internally.
+    return RunRemoteRecv(&op);
   }
 }
 
 Status RemoteCopyNode::Run() {
-  Status s = RunSend();
+  Status s = StartSend();
   if (!s.ok()) {
     Abort(s);
     return s;
   }
-  s = RunRecv();
-  if (!s.ok() && errors::IsCancelled(s) && !send_status_.ok()) {
-    // In this case, Recv is cancel because Send op failed. Return the status of
-    // send op instead.
-    Abort(send_status_);
-    return send_status_;
-  }
-  if (!s.ok()) {
-    Abort(s);
+  // StartRecv() takes care of doing the right thing to dst handle.
+  // No need to poison it after this point.
+  s = StartRecv();
+  if (!s.ok() && errors::IsCancelled(s)) {
+    Status send_status = captured_state_->GetSendStatus();
+    if (!send_status.ok()) {
+      // In this case, Recv is cancelled because the Send op failed. Return the
+      // status of the Send op instead.
+      s = send_status;
+    }
+  }
 
   src_->Unref();
-  dst_->Unref();
   ctx_->Unref();
 
   return s;
 }
 
 void RemoteCopyNode::Abort(Status status) {
-  dst_->Poison(status);
+  captured_state_->dst()->Poison(status);
   src_->Unref();
-  dst_->Unref();
   ctx_->Unref();
 }
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.h b/tensorflow/core/distributed_runtime/eager/remote_copy_node.h
index 41bb025b6cb..f6429012d74 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -26,11 +27,35 @@ limitations under the License.
 namespace tensorflow {
 namespace eager {
 
-// This node supports copy a tensor:
-// Remote -> Remote
-// Local -> Remote
-// Remote -> Local
-// To copy a tensor with a host, please use copy_to_device_node instead.
+// This node supports copying a tensor in the following ways:
+// - Remote -> Local:
+//   We don't block on the remote _Send op and start executing the local
+//   _Recv immediately after issuing the remote _Send. The local _Recv
+//   kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run)
+//   blocks until the tensor is received. If the remote _Send (or some op
+//   before it) fails, the local callback we give to EnqueueAsync will run
+//   and call CancellationManager.StartCancel(). The blocked local _Recv will
+//   get this notification and return with a cancelled error.
+//
+// - Local -> Remote:
+//   The local _Send op is synchronous and non-blocking, thus it should complete
+//   quickly. We issue remote _Recv RPC only after local _Send completes
+//   successfully. At this point, the tensor to be sent is in the local
+//   Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor
+//   to appear.
+//
+// - Remote -> Remote:
+//   We could issue both remote ops asynchronously, but if remote _Send (or some
+//   op before it) fails, we don't have a good way of cancelling the remote
+//   _Recv. The remote _Recv will deadlock in this case. The current approach
+//   to deal with this issue is to wait for remote _Send to complete before
+//   issuing remote _Recv RPC. Another option is to close the whole streaming
+//   RPC that contains the deadlocked remote _Recv. This would not unblock the
+//   deadlocked RPC on the remote machine without some extra code. Luckily, the
+//   remote -> remote case seems to be fairly rare at this point. So, the
+//   current partially synchronous approach seems fine.
+//
+// To copy a tensor within a host, please use copy_to_device_node instead.
 class RemoteCopyNode : public EagerNode {
  public:
   RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src,
@@ -43,11 +68,69 @@ class RemoteCopyNode : public EagerNode {
   void Abort(Status status) override;
 
  private:
-  Status RunSend();
-  Status RunRecv();
+  // Runs the _Send operation locally or remotely.
+  // An error return value indicates that _Send did not run successfully.
+  // An OK return value does NOT necessarily indicate that _Send has completed
+  // successfully. It might still fail after this method returns.
+  // StartSend() makes sure that captured_state_->send_status_ is set to the
+  // final _Send status after captured_state_->send_done_.WaitForNotification()
+  // returns.
+  Status StartSend();
+
+  // Synchronously runs local send `op` and returns its status.
+  Status RunLocalSend(EagerOperation* op);
+
+  // Runs the _Recv operation locally or remotely.
+  // An error return value indicates that _Recv did not run successfully. It
+  // does not indicate that the _Send op has completed since StartRecv could
+  // have encountered an error before waiting for _Send's completion.
+  // An OK return value does NOT necessarily indicate that _Recv has completed
+  // successfully (it does now, but won't when streaming RPCs are turned on).
+  // StartRecv() makes sure that the dst_ tensor handle is handled correctly
+  // (potentially after this method returns); a tensor is set in the local
+  // case, a remote shape is set in the remote case, and the dst_ handle is
+  // poisoned in either case if there is an error.
+  Status StartRecv();
+
+  // Synchronously runs local receive `op` and returns its status.
+  // Does not wait for the send to complete before running receive.
+  Status RunLocalRecv(EagerOperation* op, std::vector<Tensor>* outputs);
+
+  // Waits for send to complete, then issues remote receive `op` and
+  // returns its status.
+  Status RunRemoteRecv(EagerOperation* op);
+
+  // State that is captured by Send and/or Recv callbacks (depending on which
+  // one(s) is remote) and outlives this node in the case of remote->remote
+  // copy.
+  class CapturedSharedState {
+   public:
+    explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); }
+    ~CapturedSharedState() { dst_->Unref(); }
+
+    void SetSendStatus(Status status) {
+      send_status_.Update(status);
+      send_done_.Notify();
+    }
+
+    Status GetSendStatus() {
+      send_done_.WaitForNotification();
+      return send_status_;
+    }
+
+    TensorHandle* dst() { return dst_; }
+    CancellationManager* recv_cancellation() { return &recv_cancellation_; }
+
+   private:
+    TensorHandle* const dst_;
+    CancellationManager recv_cancellation_;
+    // send_status_ is safe to read only after send_done_.WaitForNotification()
+    // has returned.
+    Status send_status_;
+    Notification send_done_;
+  };
 
   TensorHandle* const src_;
-  TensorHandle* const dst_;
   EagerContext* const ctx_;
   EagerExecutor* const executor_;
   Device* const send_device_;
@@ -55,8 +138,7 @@ class RemoteCopyNode : public EagerNode {
   const string wire_id_;
   const uint64 recv_op_id_;
 
-  CancellationManager recv_cancellation_;
-  Status send_status_;
+  std::shared_ptr<CapturedSharedState> captured_state_;
 };
 
 }  // namespace eager
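
Note on the synchronization pattern this patch relies on: the sketch below is a minimal, self-contained C++ illustration, not part of the patch. The names Status, SharedCopyState, and FakeAsyncSend are made up for the sketch (plain standard-library stand-ins, not TensorFlow's Status, CapturedSharedState, Notification, or EagerClient). It shows why the shared state is held through a std::shared_ptr: the asynchronous send callback may run after the node that issued it is gone, the receive side blocks until the send status is known, and a failed send flips a cancellation flag the receive side can observe.

// Illustrative only: SharedCopyState mimics the role of CapturedSharedState,
// and FakeAsyncSend stands in for the asynchronous send RPC. None of these
// names exist in TensorFlow.
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>

struct Status {
  bool ok = true;
  std::string msg;
};

// Shared state owned via shared_ptr so an async callback can outlive the
// object that issued the RPC (the RemoteCopyNode in the patch).
class SharedCopyState {
 public:
  // Called from the send callback; unblocks anyone waiting in GetSendStatus().
  void SetSendStatus(Status s) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      send_status_ = std::move(s);
      send_done_ = true;
    }
    cv_.notify_all();
  }

  // Blocks until the send status is known (the remote->remote "wait for send
  // before issuing recv" behaviour described in the header comment).
  Status GetSendStatus() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return send_done_; });
    return send_status_;
  }

  // Mimics CancellationManager StartCancel()/IsCancelled().
  void StartCancel() { cancelled_ = true; }
  bool IsCancelled() const { return cancelled_; }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool send_done_ = false;
  Status send_status_;
  std::atomic<bool> cancelled_{false};
};

// Stand-in for the asynchronous send RPC: the callback runs on another thread
// and only touches the shared state, never the issuing object.
void FakeAsyncSend(std::shared_ptr<SharedCopyState> state, bool fail) {
  std::thread([state, fail] {
    Status s;
    if (fail) {
      s.ok = false;
      s.msg = "send failed";
      state->StartCancel();  // A failed send cancels the pending recv.
    }
    state->SetSendStatus(s);
  }).detach();
}

int main() {
  auto state = std::make_shared<SharedCopyState>();
  FakeAsyncSend(state, /*fail=*/true);

  // The "recv" side blocks on the send result before doing anything else.
  Status send_status = state->GetSendStatus();
  if (!send_status.ok()) {
    std::cout << "recv not issued, send error: " << send_status.msg
              << ", cancelled=" << state->IsCancelled() << std::endl;
  }
  return 0;
}

This is the same reasoning behind the patch's choice to capture captured_state_ by value in the EnqueueAsync callbacks rather than capturing `this`: the node may already be destroyed by the time the callback fires.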