Make RemoteCopyNode ready for streaming RPCs

This required:
  - Adding CapturedSharedState for the state that must survive after RemoteCopyNode
    is destroyed.
  - Allocating the receive response dynamically. This is not needed right now, but will
    be needed with streaming; I would like to keep the streaming change to a minimum.

Also, fix a race condition when accessing send_status_.
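
To make the lifetime and race fixes above concrete, here is a minimal standalone sketch
(plain C++17; SharedSendState and std::condition_variable are illustrative stand-ins for
CapturedSharedState and tensorflow::Notification, not the real types). The async Send
callback captures a std::shared_ptr to the state, so the state survives even if the node
that issued the RPC is destroyed first, and readers of the send status block until the
writer has published it:

#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>

struct Status { bool ok = true; std::string msg; };

// Toy analogue of CapturedSharedState: owns the send status and the
// latch that makes it safe to read.
class SharedSendState {
 public:
  void SetSendStatus(Status s) {
    {
      std::lock_guard<std::mutex> l(mu_);
      send_status_ = std::move(s);
      done_ = true;
    }
    cv_.notify_all();
  }

  // Blocks until SetSendStatus() has run; only after that is the
  // status safe to read (this is the race fix).
  Status GetSendStatus() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return done_; });
    return send_status_;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool done_ = false;
  Status send_status_;
};

int main() {
  auto state = std::make_shared<SharedSendState>();
  // The callback captures the shared_ptr by value, not `this`, so it
  // may safely run after the issuing object is gone.
  std::thread send_rpc([state] { state->SetSendStatus(Status{}); });
  Status s = state->GetSendStatus();  // blocks until the callback fires
  send_rpc.join();
  return s.ok ? 0 : 1;
}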

PiperOrigin-RevId: 260567277
Igor Ganichev, 2019-07-29 13:26:52 -07:00; committed by TensorFlower Gardener
parent 654e1a8b1a · commit f7dc3f6961
2 changed files with 233 additions and 117 deletions


@@ -71,25 +71,42 @@ RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
Device* recv_device, uint64 recv_op_id)
: EagerNode(),
src_(src),
dst_(dst),
ctx_(ctx),
executor_(executor),
send_device_(src->DeviceOrHostCPU(ctx)),
recv_device_(recv_device),
wire_id_(GetUniqueWireID()),
recv_op_id_(recv_op_id) {
recv_op_id_(recv_op_id),
captured_state_(std::make_shared<CapturedSharedState>(dst)) {
DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
src_->Ref();
dst_->Ref();
ctx_->Ref();
}
Status RemoteCopyNode::RunSend() {
Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
TF_RETURN_IF_ERROR(executor_->status());
op->AddInput(src_);
core::RefCountPtr<KernelAndDevice> kernel;
TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
gtl::InlinedVector<TensorValue, 4> input_vector(1);
TF_RETURN_IF_ERROR(src_->TensorValue(&input_vector[0]));
return kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr, nullptr);
}
Status RemoteCopyNode::StartSend() {
// TODO(gjn): We should consider just using the low-level SendOp::Compute()
// functionality here instead of constructing an Op.
const AttrTypeMap* types;
bool is_function = false;
TF_RETURN_IF_ERROR(AttrTypeMapForOp("_Send", &types, &is_function));
Status status = AttrTypeMapForOp("_Send", &types, &is_function);
if (!status.ok()) {
captured_state_->SetSendStatus(status);
return status;
}
DCHECK(!is_function);
EagerOperation op(ctx_, "_Send", /*is_function=*/false, types);
@@ -108,52 +125,116 @@ Status RemoteCopyNode::RunSend() {
DCHECK(send_device_ != nullptr);
if (send_device_->IsLocal()) {
TF_RETURN_IF_ERROR(executor_->status());
op.AddInput(src_);
core::RefCountPtr<KernelAndDevice> kernel;
TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(&op, &kernel));
gtl::InlinedVector<TensorValue, 4> input_vector(1);
TF_RETURN_IF_ERROR(src_->TensorValue(&input_vector[0]));
TF_RETURN_IF_ERROR(
kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr, nullptr));
status = RunLocalSend(&op);
captured_state_->SetSendStatus(status);
return status;
} else {
eager::EagerClient* eager_client;
uint64 context_id = ctx_->GetContextId();
TF_RETURN_IF_ERROR(ctx_->GetClient(send_device_, &eager_client));
std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
request->set_context_id(context_id);
auto* remote_op = request->add_queue()->mutable_operation();
TF_RETURN_IF_ERROR(ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
src_, remote_op->add_inputs(), src_->device()));
// Prepare the request
EnqueueRequest request;
request.set_context_id(ctx_->GetContextId());
auto* remote_op = request.add_queue()->mutable_operation();
status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
src_, remote_op->add_inputs(), src_->device());
if (!status.ok()) {
captured_state_->SetSendStatus(status);
return status;
}
PrepareRemoteOp(remote_op, &op);
remote_op->set_id(ctx_->RemoteMgr()->NextOpId());
auto* response = new EnqueueResponse;
eager_client->EnqueueAsync(request.get(), response,
[this, response](const Status& s) {
send_status_.Update(s);
if (!s.ok()) {
recv_cancellation_.StartCancel();
}
delete response;
});
// Issue the RPC
eager::EagerClient* eager_client;
status = ctx_->GetClient(send_device_, &eager_client);
if (!status.ok()) {
captured_state_->SetSendStatus(status);
return status;
}
const std::shared_ptr<CapturedSharedState>& captured_state =
captured_state_;
EnqueueResponse* response = new EnqueueResponse;
// If StartRecv fails very quickly, `this` can be destroyed before the
// callback below is executed. So, we can't capture `this`.
eager_client->EnqueueAsync(
&request, response, [response, captured_state](const Status& s) {
captured_state->SetSendStatus(s);
if (!s.ok()) {
captured_state->recv_cancellation()->StartCancel();
}
delete response;
});
return Status::OK();
}
return Status::OK();
}
Status RemoteCopyNode::RunRecv() {
Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
std::vector<Tensor>* outputs) {
TF_RETURN_IF_ERROR(executor_->status());
core::RefCountPtr<KernelAndDevice> kernel;
TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
gtl::InlinedVector<TensorValue, 4> input_vector;
return kernel->Run(input_vector, outputs, nullptr, nullptr, nullptr,
captured_state_->recv_cancellation());
}
Status RemoteCopyNode::RunRemoteRecv(EagerOperation* op) {
EnqueueRequest request;
uint64 context_id = ctx_->GetContextId();
request.set_context_id(context_id);
auto* remote_op = request.add_queue()->mutable_operation();
PrepareRemoteOp(remote_op, op);
remote_op->set_id(recv_op_id_);
eager::EagerClient* eager_client;
Status status = ctx_->GetClient(recv_device_, &eager_client);
if (!status.ok()) {
captured_state_->dst()->Poison(status);
return status;
}
EnqueueResponse* response = new EnqueueResponse;
// Don't issue the recv until send has completed.
// - local send will complete very quickly.
// - remote send will take some time, but remote->remote copy is
// probably rare enough that we don't care much.
// Blocks until send has completed.
Status send_status = captured_state_->GetSendStatus();
const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
Device* recv_device = recv_device_;
Notification n;
eager_client->EnqueueAsync(
&request, response,
[captured_state, response, recv_device, &n, &status](const Status& s) {
status.Update(s);
if (status.ok()) {
status = captured_state->dst()->SetRemoteShape(
response->queue_response(0).shape(0), recv_device);
} else {
captured_state->dst()->Poison(status);
}
delete response;
n.Notify();
});
n.WaitForNotification();
return status;
}
Status RemoteCopyNode::StartRecv() {
// TODO(gjn): We should consider just using the low-level RecvOp::Compute()
// functionality here instead of constructing an Op.
const AttrTypeMap* types;
bool is_function = false;
TF_RETURN_IF_ERROR(AttrTypeMapForOp("_Recv", &types, &is_function));
Status status = AttrTypeMapForOp("_Recv", &types, &is_function);
if (!status.ok()) {
captured_state_->dst()->Poison(status);
return status;
}
DCHECK(!is_function);
EagerOperation op(ctx_, "_Recv", /*is_function=*/false, types);
@@ -170,93 +251,46 @@ Status RemoteCopyNode::RunRecv() {
op.MutableAttrs()->Set("tensor_type", src_->dtype);
if (recv_device_->IsLocal()) {
TF_RETURN_IF_ERROR(executor_->status());
core::RefCountPtr<KernelAndDevice> kernel;
TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(&op, &kernel));
std::vector<Tensor> outputs;
gtl::InlinedVector<TensorValue, 4> input_vector;
TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, nullptr, nullptr,
nullptr, &recv_cancellation_));
return dst_->SetTensor(outputs[0]);
} else {
eager::EagerClient* eager_client;
uint64 context_id = ctx_->GetContextId();
TF_RETURN_IF_ERROR(ctx_->GetClient(recv_device_, &eager_client));
std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
request->set_context_id(context_id);
auto* remote_op = request->add_queue()->mutable_operation();
PrepareRemoteOp(remote_op, &op);
remote_op->set_id(recv_op_id_);
EnqueueResponse response;
Status status;
Notification n;
CancellationToken token = recv_cancellation_.get_cancellation_token();
bool already_cancelled =
!recv_cancellation_.RegisterCallback(token, [&n, &status] {
status.Update(errors::Cancelled(
"Recv op is cancelled due to an error in Send op."));
n.Notify();
});
if (already_cancelled) {
status =
errors::Cancelled("Recv op is cancelled due to an error in Send op.");
} else {
// Note(fishx): When the recv op is cancelled, we don't clean up the
// state on the remote server. So the recv op may have run successfully
// on the remote server even though we cancelled it on the client.
eager_client->EnqueueAsync(request.get(), &response,
[this, &n, &status](const Status& s) {
if (recv_cancellation_.IsCancelled()) return;
status.Update(s);
n.Notify();
});
n.WaitForNotification();
recv_cancellation_.DeregisterCallback(token);
std::vector<Tensor> outputs(1);
status = RunLocalRecv(&op, &outputs);
if (!status.ok()) {
captured_state_->dst()->Poison(status);
return status;
}
TF_RETURN_IF_ERROR(status);
return dst_->SetRemoteShape(response.queue_response(0).shape(0),
recv_device_);
return captured_state_->dst()->SetTensor(outputs[0]);
} else {
// Handles captured_state_->dst_ internally.
return RunRemoteRecv(&op);
}
}
Status RemoteCopyNode::Run() {
Status s = RunSend();
Status s = StartSend();
if (!s.ok()) {
Abort(s);
return s;
}
s = RunRecv();
if (!s.ok() && errors::IsCancelled(s) && !send_status_.ok()) {
// In this case, Recv is cancelled because the Send op failed. Return the
// status of the Send op instead.
Abort(send_status_);
return send_status_;
}
if (!s.ok()) {
Abort(s);
// StartRecv() takes care of doing the right thing to the dst handle.
// No need to poison it after this point.
s = StartRecv();
if (!s.ok() && errors::IsCancelled(s)) {
Status send_status = captured_state_->GetSendStatus();
if (!send_status.ok()) {
// In this case, Recv is cancelled because the Send op failed. Return the
// status of the Send op instead.
s = send_status;
}
}
src_->Unref();
dst_->Unref();
ctx_->Unref();
return s;
}
void RemoteCopyNode::Abort(Status status) {
dst_->Poison(status);
captured_state_->dst()->Poison(status);
src_->Unref();
dst_->Unref();
ctx_->Unref();
}
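
RunRemoteRecv() above issues EnqueueAsync and then blocks on a Notification until the
completion callback has recorded the final status, turning the async RPC into a
synchronous call. A minimal standalone sketch of that shape, assuming nothing from
TensorFlow (std::promise plays the role of Notification plus the captured status, and
EnqueueAsync here is a toy stand-in, not the real eager client API, whose signature also
takes a request and response):

#include <functional>
#include <future>
#include <memory>
#include <string>
#include <thread>
#include <utility>

struct Status { bool ok = true; std::string msg; };
using Done = std::function<void(const Status&)>;

// Stand-in for eager_client->EnqueueAsync(request, response, done): the
// real call issues an RPC; here we just run the callback on another thread.
std::thread EnqueueAsync(Done done) {
  return std::thread([done = std::move(done)] { done(Status{}); });
}

// Mirrors the shape of RunRemoteRecv(): issue the async enqueue, block
// until the completion callback has recorded the final status, return it.
Status RunRemoteRecvSketch() {
  auto p = std::make_shared<std::promise<Status>>();
  std::future<Status> fut = p->get_future();
  std::thread rpc = EnqueueAsync([p](const Status& s) { p->set_value(s); });
  Status s = fut.get();  // plays the role of n.WaitForNotification()
  rpc.join();
  return s;
}

int main() { return RunRemoteRecvSketch().ok ? 0 : 1; }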


@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
@@ -26,11 +27,35 @@ limitations under the License.
namespace tensorflow {
namespace eager {
// This node supports copying a tensor:
// Remote -> Remote
// Local -> Remote
// Remote -> Local
// To copy a tensor within a host, please use copy_to_device_node instead.
// This node supports copying a tensor in the following ways:
// - Remote -> Local:
// We don't block on the remote _Send op and start executing the local
// _Recv immediately after issuing the remote _Send. The local _Recv
// kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run)
// blocks until the tensor is received. If the remote _Send (or some op
// before it) fails, the local callback we give to EnqueueAsync will run
// and call CancellationManager.StartCancel(). The blocked local _Recv will
// get this notification and return with a cancelled error.
//
// - Local -> Remote:
// The local _Send op is synchronous and non-blocking, thus it should complete
// quickly. We issue remote _Recv RPC only after local _Send completes
// successfully. At this point, the tensor to be sent is in the local
// Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor
// to appear.
//
// - Remote -> Remote:
// We could issue both remote ops asynchronously, but if remote _Send (or some
// op before it) fails, we don't have a good way of cancelling the remote
// _Recv. The remote _Recv will deadlock in this case. The current approach
// to deal with this issue is to wait for remote _Send to complete before
// issuing remote _Recv RPC. Another option is to close the whole streaming
// RPC that contains the deadlocked remote _Recv. This would not unblock the
// deadlocked RPC on the remote machine without some extra code. Luckily, the
// remote -> remote case seems to be fairly rare at this point. So, the
// current partially synchronous approach seems fine.
//
// To copy a tensor within a host, please use copy_to_device_node instead.
class RemoteCopyNode : public EagerNode {
public:
RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src,
@@ -43,11 +68,69 @@ class RemoteCopyNode : public EagerNode {
void Abort(Status status) override;
private:
Status RunSend();
Status RunRecv();
// Runs the _Send operation locally or remotely.
// An error return value indicates that _Send did not run successfully.
// An OK return value does NOT necessarily indicate that _Send has completed
// successfully. It might still fail after this method returns.
// StartSend() makes sure that captured_state_->send_status_ is set to the
// final _Send status after captured_state_->send_done_.WaitForNotification()
// returns.
Status StartSend();
// Synchronously runs local send `op` and returns its status.
Status RunLocalSend(EagerOperation* op);
// Runs the _Recv operation locally or remotely.
// An error return value indicates that _Recv did not run successfully. It
// does not indicate that the _Send op has completed, since StartRecv could
// have encountered an error before waiting for _Send's completion.
// An OK return value does NOT necessarily indicate that _Recv has completed
// successfully (it does now, but won't once streaming RPCs are turned on).
// StartRecv() makes sure that the dst_ tensor handle is handled correctly
// (potentially after this method returns): a tensor is set in the local
// case, a remote shape is set in the remote case, and the dst_ handle is
// poisoned in either case if there is an error.
Status StartRecv();
// Synchronously runs local receive `op` and returns its status.
// Does not wait for the send to complete before running receive.
Status RunLocalRecv(EagerOperation* op, std::vector<Tensor>* outputs);
// Waits for send to complete, then issues remote receive `op` and
// returns its status.
Status RunRemoteRecv(EagerOperation* op);
// State that is captured by the Send and/or Recv callbacks (depending on
// which of them are remote) and outlives this node in the case of a
// remote->remote copy.
class CapturedSharedState {
public:
explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); }
~CapturedSharedState() { dst_->Unref(); }
void SetSendStatus(Status status) {
send_status_.Update(status);
send_done_.Notify();
}
Status GetSendStatus() {
send_done_.WaitForNotification();
return send_status_;
}
TensorHandle* dst() { return dst_; }
CancellationManager* recv_cancellation() { return &recv_cancellation_; }
private:
TensorHandle* const dst_;
CancellationManager recv_cancellation_;
// send_status_ is safe to read only after send_done_.WaitForNotification()
// has returned.
Status send_status_;
Notification send_done_;
};
TensorHandle* const src_;
TensorHandle* const dst_;
EagerContext* const ctx_;
EagerExecutor* const executor_;
Device* const send_device_;
@@ -55,8 +138,7 @@ class RemoteCopyNode : public EagerNode {
const string wire_id_;
const uint64 recv_op_id_;
CancellationManager recv_cancellation_;
Status send_status_;
std::shared_ptr<CapturedSharedState> captured_state_;
};
} // namespace eager
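
The Remote -> Local flow described in the header comment relies on cancellation to avoid
wedging a blocked receive. A toy standalone sketch of that interaction (plain C++17;
ToyCancellationManager is an illustrative stand-in for tensorflow::CancellationManager,
not its real API): the receiver blocks waiting for a tensor, the failed send's callback
calls StartCancel(), and the blocked receive wakes up with a cancelled result instead of
hanging:

#include <condition_variable>
#include <mutex>
#include <thread>

class ToyCancellationManager {
 public:
  // Called by the send-side completion callback when the send fails.
  void StartCancel() {
    { std::lock_guard<std::mutex> l(mu_); cancelled_ = true; }
    cv_.notify_all();
  }

  // The "blocked local _Recv": waits until either data arrives or the
  // operation is cancelled. Returns true iff data arrived.
  bool WaitForDataOrCancel() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return has_data_ || cancelled_; });
    return has_data_;
  }

  void Deliver() {
    { std::lock_guard<std::mutex> l(mu_); has_data_ = true; }
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool has_data_ = false;
  bool cancelled_ = false;
};

int main() {
  ToyCancellationManager m;
  // Simulate a failed send: its completion callback cancels the recv.
  std::thread send([&m] { m.StartCancel(); });
  bool got_tensor = m.WaitForDataOrCancel();
  send.join();
  // Exit 0 when the recv observed the cancellation instead of wedging.
  return got_tensor ? 1 : 0;
}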