Make RemoteCopyNode ready for streaming RPCs

This required:
- Adding CapturedSharedState for the state that must survive after RemoteCopyNode
  is destroyed.
- Allocating the receive response dynamically. This is not needed right now, but
  will be needed with streaming. I would like to keep the streaming change to a
  minimum.

Also, fix a race condition when accessing send_status_.

PiperOrigin-RevId: 260567277
Parent: 654e1a8b1a
Commit: f7dc3f6961
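Before the diff, a minimal sketch of the pattern the new CapturedSharedState enables. This is not part of the change itself; SharedState, AsyncEnqueue and IssueSend are illustrative stand-ins rather than TensorFlow APIs, and the include paths are assumptions. It demonstrates the two points of the commit message: the RPC completion callback captures a std::shared_ptr copy of the state instead of `this`, so the state can outlive the node, and send_status is only read after a Notification fires, which is what removes the race.

#include <functional>
#include <memory>
#include <thread>
#include <utility>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"

namespace {

// Illustrative stand-in for the state that must survive the node: the RPC
// callback owns a shared_ptr copy, so destruction of the issuing object does
// not affect the callback.
struct SharedState {
  tensorflow::Status send_status;      // written only by the callback
  tensorflow::Notification send_done;  // guards all reads of send_status

  void SetSendStatus(tensorflow::Status s) {
    send_status = std::move(s);
    send_done.Notify();
  }

  tensorflow::Status GetSendStatus() {
    send_done.WaitForNotification();  // read send_status only after this
    return send_status;
  }
};

// Illustrative stand-in for EagerClient::EnqueueAsync: runs the callback on
// another thread, possibly long after the caller has returned.
void AsyncEnqueue(std::function<void(const tensorflow::Status&)> done) {
  std::thread([done] { done(tensorflow::Status::OK()); }).detach();
}

void IssueSend() {
  auto state = std::make_shared<SharedState>();
  // Capture the shared_ptr by value; never `this` or raw members, since the
  // node that issued the RPC may already be destroyed when `done` runs.
  AsyncEnqueue(
      [state](const tensorflow::Status& s) { state->SetSendStatus(s); });
  // Safe: blocks on the notification before touching send_status.
  tensorflow::Status send_status = state->GetSendStatus();
  (void)send_status;
}

}  // namespace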
@@ -71,25 +71,42 @@ RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
                                Device* recv_device, uint64 recv_op_id)
     : EagerNode(),
       src_(src),
       dst_(dst),
       ctx_(ctx),
       executor_(executor),
       send_device_(src->DeviceOrHostCPU(ctx)),
       recv_device_(recv_device),
       wire_id_(GetUniqueWireID()),
-      recv_op_id_(recv_op_id) {
+      recv_op_id_(recv_op_id),
+      captured_state_(std::make_shared<CapturedSharedState>(dst)) {
   DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
   src_->Ref();
   dst_->Ref();
   ctx_->Ref();
 }
 
-Status RemoteCopyNode::RunSend() {
+Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
+  TF_RETURN_IF_ERROR(executor_->status());
+
+  op->AddInput(src_);
+
+  core::RefCountPtr<KernelAndDevice> kernel;
+  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
+
+  gtl::InlinedVector<TensorValue, 4> input_vector(1);
+  TF_RETURN_IF_ERROR(src_->TensorValue(&input_vector[0]));
+
+  return kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr, nullptr);
+}
+
+Status RemoteCopyNode::StartSend() {
   // TODO(gjn): We should consider just using the low-level SendOp::Compute()
   // functionality here instead of constructing an Op.
   const AttrTypeMap* types;
   bool is_function = false;
-  TF_RETURN_IF_ERROR(AttrTypeMapForOp("_Send", &types, &is_function));
+  Status status = AttrTypeMapForOp("_Send", &types, &is_function);
+  if (!status.ok()) {
+    captured_state_->SetSendStatus(status);
+    return status;
+  }
   DCHECK(!is_function);
   EagerOperation op(ctx_, "_Send", /*is_function=*/false, types);
 
@@ -108,52 +125,116 @@ Status RemoteCopyNode::RunSend() {
   DCHECK(send_device_ != nullptr);
 
   if (send_device_->IsLocal()) {
-    TF_RETURN_IF_ERROR(executor_->status());
-
-    op.AddInput(src_);
-
-    core::RefCountPtr<KernelAndDevice> kernel;
-    TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(&op, &kernel));
-
-    gtl::InlinedVector<TensorValue, 4> input_vector(1);
-    TF_RETURN_IF_ERROR(src_->TensorValue(&input_vector[0]));
-
-    TF_RETURN_IF_ERROR(
-        kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr, nullptr));
+    status = RunLocalSend(&op);
+    captured_state_->SetSendStatus(status);
+    return status;
   } else {
-    eager::EagerClient* eager_client;
-    uint64 context_id = ctx_->GetContextId();
-    TF_RETURN_IF_ERROR(ctx_->GetClient(send_device_, &eager_client));
-
-    std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
-    request->set_context_id(context_id);
-
-    auto* remote_op = request->add_queue()->mutable_operation();
-    TF_RETURN_IF_ERROR(ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
-        src_, remote_op->add_inputs(), src_->device()));
+    // Prepare the request
+    EnqueueRequest request;
+    request.set_context_id(ctx_->GetContextId());
+    auto* remote_op = request.add_queue()->mutable_operation();
+    status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
+        src_, remote_op->add_inputs(), src_->device());
+    if (!status.ok()) {
+      captured_state_->SetSendStatus(status);
+      return status;
+    }
 
     PrepareRemoteOp(remote_op, &op);
     remote_op->set_id(ctx_->RemoteMgr()->NextOpId());
 
-    auto* response = new EnqueueResponse;
-    eager_client->EnqueueAsync(request.get(), response,
-                               [this, response](const Status& s) {
-                                 send_status_.Update(s);
-                                 if (!s.ok()) {
-                                   recv_cancellation_.StartCancel();
-                                 }
-                                 delete response;
-                               });
+    // Issue the RPC
+    eager::EagerClient* eager_client;
+    status = ctx_->GetClient(send_device_, &eager_client);
+    if (!status.ok()) {
+      captured_state_->SetSendStatus(status);
+      return status;
+    }
+
+    const std::shared_ptr<CapturedSharedState>& captured_state =
+        captured_state_;
+    EnqueueResponse* response = new EnqueueResponse;
+    // If StartRecv fails very quickly, `this` can be destroyed before the
+    // callback below is executed. So, we can't capture `this`.
+    eager_client->EnqueueAsync(
+        &request, response, [response, captured_state](const Status& s) {
+          captured_state->SetSendStatus(s);
+          if (!s.ok()) {
+            captured_state->recv_cancellation()->StartCancel();
+          }
+          delete response;
+        });
+    return Status::OK();
   }
-  return Status::OK();
 }
 
-Status RemoteCopyNode::RunRecv() {
+Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
+                                    std::vector<Tensor>* outputs) {
+  TF_RETURN_IF_ERROR(executor_->status());
+
+  core::RefCountPtr<KernelAndDevice> kernel;
+  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
+
+  gtl::InlinedVector<TensorValue, 4> input_vector;
+  return kernel->Run(input_vector, outputs, nullptr, nullptr, nullptr,
+                     captured_state_->recv_cancellation());
+}
+
+Status RemoteCopyNode::RunRemoteRecv(EagerOperation* op) {
+  EnqueueRequest request;
+  uint64 context_id = ctx_->GetContextId();
+  request.set_context_id(context_id);
+  auto* remote_op = request.add_queue()->mutable_operation();
+  PrepareRemoteOp(remote_op, op);
+  remote_op->set_id(recv_op_id_);
+
+  eager::EagerClient* eager_client;
+  Status status = ctx_->GetClient(recv_device_, &eager_client);
+  if (!status.ok()) {
+    captured_state_->dst()->Poison(status);
+    return status;
+  }
+
+  EnqueueResponse* response = new EnqueueResponse;
+
+  // Don't issue the recv until send has completed.
+  //  - local send will complete very quickly.
+  //  - remote send will take some time, but remote->remote copy is
+  //    probably rare enough that we don't care much.
+  // Blocks until send has completed.
+  Status send_status = captured_state_->GetSendStatus();
+
+  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
+  Device* recv_device = recv_device_;
+  Notification n;
+  eager_client->EnqueueAsync(
+      &request, response,
+      [captured_state, response, recv_device, &n, &status](const Status& s) {
+        status.Update(s);
+        if (status.ok()) {
+          status = captured_state->dst()->SetRemoteShape(
+              response->queue_response(0).shape(0), recv_device);
+        } else {
+          captured_state->dst()->Poison(status);
+        }
+        delete response;
+        n.Notify();
+      });
+  n.WaitForNotification();
+
+  return status;
+}
+
+Status RemoteCopyNode::StartRecv() {
   // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
   // functionality here instead of constructing an Op.
   const AttrTypeMap* types;
   bool is_function = false;
-  TF_RETURN_IF_ERROR(AttrTypeMapForOp("_Recv", &types, &is_function));
+  Status status = AttrTypeMapForOp("_Recv", &types, &is_function);
+  if (!status.ok()) {
+    captured_state_->dst()->Poison(status);
+    return status;
+  }
   DCHECK(!is_function);
   EagerOperation op(ctx_, "_Recv", /*is_function=*/false, types);
 
@@ -170,93 +251,46 @@ Status RemoteCopyNode::RunRecv() {
   op.MutableAttrs()->Set("tensor_type", src_->dtype);
 
   if (recv_device_->IsLocal()) {
-    TF_RETURN_IF_ERROR(executor_->status());
-
-    core::RefCountPtr<KernelAndDevice> kernel;
-    TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(&op, &kernel));
-
-    std::vector<Tensor> outputs;
-    gtl::InlinedVector<TensorValue, 4> input_vector;
-    TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, nullptr, nullptr,
-                                   nullptr, &recv_cancellation_));
-    return dst_->SetTensor(outputs[0]);
-  } else {
-    eager::EagerClient* eager_client;
-    uint64 context_id = ctx_->GetContextId();
-    TF_RETURN_IF_ERROR(ctx_->GetClient(recv_device_, &eager_client));
-
-    std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
-
-    request->set_context_id(context_id);
-
-    auto* remote_op = request->add_queue()->mutable_operation();
-    PrepareRemoteOp(remote_op, &op);
-    remote_op->set_id(recv_op_id_);
-
-    EnqueueResponse response;
-    Status status;
-    Notification n;
-
-    CancellationToken token = recv_cancellation_.get_cancellation_token();
-    bool already_cancelled =
-        !recv_cancellation_.RegisterCallback(token, [&n, &status] {
-          status.Update(errors::Cancelled(
-              "Recv op is cancelled due to an error in Send op."));
-          n.Notify();
-        });
-
-    if (already_cancelled) {
-      status =
-          errors::Cancelled("Recv op is cancelled due to an error in Send op.");
-    } else {
-      // Note(fishx): When the recv op is cancelled, we doesn't clean up the
-      // state on remote server. So the recv op may ran successfully on the
-      // remote server even though we cancel it on client.
-      eager_client->EnqueueAsync(request.get(), &response,
-                                 [this, &n, &status](const Status& s) {
-                                   if (recv_cancellation_.IsCancelled()) return;
-                                   status.Update(s);
-                                   n.Notify();
-                                 });
-      n.WaitForNotification();
-      recv_cancellation_.DeregisterCallback(token);
+    std::vector<Tensor> outputs(1);
+    status = RunLocalRecv(&op, &outputs);
+    if (!status.ok()) {
+      captured_state_->dst()->Poison(status);
+      return status;
     }
 
-    TF_RETURN_IF_ERROR(status);
-
-    return dst_->SetRemoteShape(response.queue_response(0).shape(0),
-                                recv_device_);
+    return captured_state_->dst()->SetTensor(outputs[0]);
+  } else {
+    // Handles captured_state_->dst_ internally.
+    return RunRemoteRecv(&op);
   }
 }
 
 Status RemoteCopyNode::Run() {
-  Status s = RunSend();
+  Status s = StartSend();
   if (!s.ok()) {
     Abort(s);
     return s;
   }
 
-  s = RunRecv();
-  if (!s.ok() && errors::IsCancelled(s) && !send_status_.ok()) {
-    // In this case, Recv is cancel because Send op failed. Return the status of
-    // send op instead.
-    Abort(send_status_);
-    return send_status_;
-  }
-  if (!s.ok()) {
-    Abort(s);
+  // StartRecv() takes care of doing the right thing to dst handle.
+  // No need to poison it after this point.
+  s = StartRecv();
+  if (!s.ok() && errors::IsCancelled(s)) {
+    Status send_status = captured_state_->GetSendStatus();
+    if (!send_status.ok()) {
+      // In this case, Recv is cancelled because the Send op failed. Return the
+      // status of the Send op instead.
+      s = send_status;
+    }
   }
 
   src_->Unref();
   dst_->Unref();
   ctx_->Unref();
   return s;
 }
 
 void RemoteCopyNode::Abort(Status status) {
-  dst_->Poison(status);
+  captured_state_->dst()->Poison(status);
   src_->Unref();
   dst_->Unref();
   ctx_->Unref();
 }
 
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -26,11 +27,35 @@ limitations under the License.
 namespace tensorflow {
 namespace eager {
 
-// This node supports copy a tensor:
-// Remote -> Remote
-// Local -> Remote
-// Remote -> Local
-// To copy a tensor with a host, please use copy_to_device_node instead.
+// This node supports copying a tensor in the following way:
+// - Remote -> Local:
+//   We don't block on the remote _Send op and start executing the local
+//   _Recv immediately after issuing the remote _Send. The local _Recv
+//   kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run)
+//   blocks until the tensor is received. If the remote _Send (or some op
+//   before it) fails, the local callback we give to EnqueueAsync will run
+//   and call CancellationManager.StartCancel(). The blocked local _Recv will
+//   get this notification and return with a cancelled error.
+//
+// - Local -> Remote:
+//   The local _Send op is synchronous and non-blocking, thus it should complete
+//   quickly. We issue remote _Recv RPC only after local _Send completes
+//   successfully. At this point, the tensor to be sent is in the local
+//   Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor
+//   to appear.
+//
+// - Remote -> Remote:
+//   We could issue both remote ops asynchronously, but if remote _Send (or some
+//   op before it) fails, we don't have a good way of cancelling the remote
+//   _Recv. The remote _Recv will deadlock in this case. The current approach
+//   to deal with this issue is to wait for remote _Send to complete before
+//   issuing remote _Recv RPC. Another option is to close the whole streaming
+//   RPC that contains the deadlocked remote _Recv. This would not unblock the
+//   deadlocked RPC on the remote machine without some extra code. Luckily, the
+//   remote -> remote case seems to be fairly rare at this point. So, the
+//   current partially synchronous approach seems fine.
+//
+// To copy a tensor within a host, please use copy_to_device_node instead.
 class RemoteCopyNode : public EagerNode {
  public:
   RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src,
@@ -43,11 +68,69 @@ class RemoteCopyNode : public EagerNode {
   void Abort(Status status) override;
 
  private:
-  Status RunSend();
-  Status RunRecv();
+  // Runs the _Send operation locally or remotely.
+  // An error return value indicates that _Send did not run successfully.
+  // An OK return value does NOT necessarily indicate that _Send has completed
+  // successfully. It might still fail after this method returns.
+  // StartSend() makes sure that captured_state_->send_status_ is set to the
+  // final _Send status after captured_state->send_done_.WaitForNotification()
+  // returns.
+  Status StartSend();
+
+  // Synchronously runs local send `op` and returns its status.
+  Status RunLocalSend(EagerOperation* op);
+
+  // Runs the _Recv operation locally or remotely.
+  // An error return value indicates that _Recv did not run successfully. It
+  // does not indicate that _Send op has completed since StartRecv could have
+  // encountered an error before waiting for _Send's completion.
+  // An OK return value does NOT necessarily indicate that _Recv has completed
+  // successfully (it does now, but won't when streaming RPCs are turned on).
+  // StartRecv() makes sure that dst_ tensor handle is handled correctly
+  // (potentially after this method returns); a tensor is set in the local
+  // case, a remote shape is set in the remote case, the dst_ handle is
+  // poisoned in either case if there is an error.
+  Status StartRecv();
+
+  // Synchronously runs local receive `op` and returns its status.
+  // Does not wait for the send to complete before running receive.
+  Status RunLocalRecv(EagerOperation* op, std::vector<Tensor>* outputs);
+
+  // Waits for send to complete, then issues remote receive `op` and
+  // returns its status.
+  Status RunRemoteRecv(EagerOperation* op);
+
+  // State that is captured by Send and/or Recv callbacks (depending on which
+  // one(s) is remote) and outlives this node in the case of remote->remote
+  // copy.
+  class CapturedSharedState {
+   public:
+    explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); }
+    ~CapturedSharedState() { dst_->Unref(); }
+
+    void SetSendStatus(Status status) {
+      send_status_.Update(status);
+      send_done_.Notify();
+    }
+
+    Status GetSendStatus() {
+      send_done_.WaitForNotification();
+      return send_status_;
+    }
+
+    TensorHandle* dst() { return dst_; }
+    CancellationManager* recv_cancellation() { return &recv_cancellation_; }
+
+   private:
+    TensorHandle* const dst_;
+    CancellationManager recv_cancellation_;
+    // send_status_ is safe to read only after send_done_.WaitForNotification()
+    // has returned.
+    Status send_status_;
+    Notification send_done_;
+  };
+
   TensorHandle* const src_;
   TensorHandle* const dst_;
   EagerContext* const ctx_;
   EagerExecutor* const executor_;
   Device* const send_device_;
@@ -55,8 +138,7 @@ class RemoteCopyNode : public EagerNode {
   const string wire_id_;
   const uint64 recv_op_id_;
 
-  CancellationManager recv_cancellation_;
-  Status send_status_;
+  std::shared_ptr<CapturedSharedState> captured_state_;
 };
 
 }  // namespace eager
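The Remote -> Local flow described in the new header comment (a failing remote _Send triggers CancellationManager::StartCancel(), which unblocks the local _Recv with a cancelled error) can be illustrated in isolation. Below is a minimal sketch, not part of the change, using only the CancellationManager and Notification APIs that already appear in this diff; FakeBlockingRecv and Demo are illustrative stand-ins for the _Recv kernel and the EnqueueAsync callback, and the include paths are assumptions.

#include <thread>

#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"

namespace {

using tensorflow::CancellationManager;
using tensorflow::CancellationToken;
using tensorflow::Notification;
using tensorflow::Status;

// Stand-in for the blocked local _Recv: instead of waiting for a tensor, it
// waits until the cancellation manager fires.
Status FakeBlockingRecv(CancellationManager* cm) {
  Notification cancelled;
  CancellationToken token = cm->get_cancellation_token();
  bool registered =
      cm->RegisterCallback(token, [&cancelled] { cancelled.Notify(); });
  if (!registered) {
    // StartCancel() already ran before we began waiting.
    return tensorflow::errors::Cancelled("Recv cancelled before it started");
  }
  cancelled.WaitForNotification();  // unblocked by StartCancel()
  cm->DeregisterCallback(token);    // waits for the callback to finish
  return tensorflow::errors::Cancelled(
      "Recv op is cancelled due to an error in Send op.");
}

void Demo() {
  CancellationManager recv_cancellation;
  Status recv_status;
  std::thread recv_thread(
      [&] { recv_status = FakeBlockingRecv(&recv_cancellation); });

  // What the EnqueueAsync completion callback does when the remote _Send
  // (or an op before it) fails:
  Status send_status = tensorflow::errors::Internal("remote _Send failed");
  if (!send_status.ok()) {
    recv_cancellation.StartCancel();  // unblocks the waiting recv
  }

  recv_thread.join();
  // recv_status now holds a Cancelled error instead of hanging forever.
}

}  // namespace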