Make RemoteCopyNode ready for streaming RPCs
This required:
- Adding CapturedSharedState for the state that must survive after
  RemoteCopyNode is destroyed.
- Allocating the receive response dynamically. This is not needed right now,
  but will be needed with streaming. I would like to keep the streaming change
  to a minimum.

Also, fix a race condition accessing send_status_.
PiperOrigin-RevId: 260567277
Parent: 654e1a8b1a
Commit: f7dc3f6961
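The core of the change is visible throughout the diff below: anything an async RPC callback may touch after the node is gone moves into a reference-counted CapturedSharedState, and the callback captures a std::shared_ptr to that state instead of `this`. A minimal sketch of that capture pattern, using illustrative stand-in names rather than the actual TensorFlow types:

```cpp
// Minimal sketch of the capture pattern used in this commit; SharedState and
// Node are illustrative stand-ins, not the actual TensorFlow classes.
#include <functional>
#include <iostream>
#include <memory>

struct SharedState {
  // In the real change this holds the dst handle, the CancellationManager,
  // the send status and its Notification; a flag stands in for all of it.
  bool cancelled = false;
};

class Node {
 public:
  Node() : state_(std::make_shared<SharedState>()) {}

  // `enqueue` stands in for an async RPC such as EnqueueAsync: it stores the
  // completion callback and invokes it later, possibly after the Node died.
  void Start(const std::function<void(std::function<void(bool)>)>& enqueue) {
    std::shared_ptr<SharedState> state = state_;  // capture state, not `this`
    enqueue([state](bool ok) {
      if (!ok) state->cancelled = true;  // safe even if the Node is gone
      std::cout << "callback ran, cancelled=" << state->cancelled << "\n";
    });
  }

 private:
  std::shared_ptr<SharedState> state_;
};

int main() {
  std::function<void(bool)> pending;
  auto node = std::make_unique<Node>();
  node->Start([&pending](std::function<void(bool)> done) {
    pending = std::move(done);  // the "RPC" completes later
  });
  node.reset();           // the node is destroyed before the RPC finishes
  pending(/*ok=*/false);  // no use-after-free: the shared state is still alive
}
```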
@@ -71,25 +71,42 @@ RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
                                Device* recv_device, uint64 recv_op_id)
     : EagerNode(),
       src_(src),
-      dst_(dst),
       ctx_(ctx),
       executor_(executor),
       send_device_(src->DeviceOrHostCPU(ctx)),
       recv_device_(recv_device),
       wire_id_(GetUniqueWireID()),
-      recv_op_id_(recv_op_id) {
+      recv_op_id_(recv_op_id),
+      captured_state_(std::make_shared<CapturedSharedState>(dst)) {
   DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
   src_->Ref();
-  dst_->Ref();
   ctx_->Ref();
 }

-Status RemoteCopyNode::RunSend() {
+Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
+  TF_RETURN_IF_ERROR(executor_->status());
+
+  op->AddInput(src_);
+
+  core::RefCountPtr<KernelAndDevice> kernel;
+  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
+
+  gtl::InlinedVector<TensorValue, 4> input_vector(1);
+  TF_RETURN_IF_ERROR(src_->TensorValue(&input_vector[0]));
+
+  return kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr, nullptr);
+}
+
+Status RemoteCopyNode::StartSend() {
   // TODO(gjn): We should consider just using the low-level SendOp::Compute()
   // functionality here instead of constructing an Op.
   const AttrTypeMap* types;
   bool is_function = false;
-  TF_RETURN_IF_ERROR(AttrTypeMapForOp("_Send", &types, &is_function));
+  Status status = AttrTypeMapForOp("_Send", &types, &is_function);
+  if (!status.ok()) {
+    captured_state_->SetSendStatus(status);
+    return status;
+  }
   DCHECK(!is_function);
   EagerOperation op(ctx_, "_Send", /*is_function=*/false, types);

@@ -108,52 +125,116 @@ Status RemoteCopyNode::RunSend() {
   DCHECK(send_device_ != nullptr);

   if (send_device_->IsLocal()) {
-    TF_RETURN_IF_ERROR(executor_->status());
-
-    op.AddInput(src_);
-
-    core::RefCountPtr<KernelAndDevice> kernel;
-    TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(&op, &kernel));
-
-    gtl::InlinedVector<TensorValue, 4> input_vector(1);
-    TF_RETURN_IF_ERROR(src_->TensorValue(&input_vector[0]));
-
-    TF_RETURN_IF_ERROR(
-        kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr, nullptr));
+    status = RunLocalSend(&op);
+    captured_state_->SetSendStatus(status);
+    return status;
   } else {
-    eager::EagerClient* eager_client;
-    uint64 context_id = ctx_->GetContextId();
-    TF_RETURN_IF_ERROR(ctx_->GetClient(send_device_, &eager_client));
-
-    std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
-    request->set_context_id(context_id);
-
-    auto* remote_op = request->add_queue()->mutable_operation();
-    TF_RETURN_IF_ERROR(ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
-        src_, remote_op->add_inputs(), src_->device()));
+    // Prepare the request
+    EnqueueRequest request;
+    request.set_context_id(ctx_->GetContextId());
+    auto* remote_op = request.add_queue()->mutable_operation();
+    status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
+        src_, remote_op->add_inputs(), src_->device());
+    if (!status.ok()) {
+      captured_state_->SetSendStatus(status);
+      return status;
+    }

     PrepareRemoteOp(remote_op, &op);
     remote_op->set_id(ctx_->RemoteMgr()->NextOpId());

-    auto* response = new EnqueueResponse;
-    eager_client->EnqueueAsync(request.get(), response,
-                               [this, response](const Status& s) {
-                                 send_status_.Update(s);
-                                 if (!s.ok()) {
-                                   recv_cancellation_.StartCancel();
-                                 }
-                                 delete response;
-                               });
+    // Issue the RPC
+    eager::EagerClient* eager_client;
+    status = ctx_->GetClient(send_device_, &eager_client);
+    if (!status.ok()) {
+      captured_state_->SetSendStatus(status);
+      return status;
+    }
+
+    const std::shared_ptr<CapturedSharedState>& captured_state =
+        captured_state_;
+    EnqueueResponse* response = new EnqueueResponse;
+    // If StartRecv fails very quickly, `this` can be destroyed before the
+    // callback below is executed. So, we can't capture `this`.
+    eager_client->EnqueueAsync(
+        &request, response, [response, captured_state](const Status& s) {
+          captured_state->SetSendStatus(s);
+          if (!s.ok()) {
+            captured_state->recv_cancellation()->StartCancel();
+          }
+          delete response;
+        });
+    return Status::OK();
   }
-  return Status::OK();
 }

-Status RemoteCopyNode::RunRecv() {
+Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
+                                    std::vector<Tensor>* outputs) {
+  TF_RETURN_IF_ERROR(executor_->status());
+
+  core::RefCountPtr<KernelAndDevice> kernel;
+  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
+
+  gtl::InlinedVector<TensorValue, 4> input_vector;
+  return kernel->Run(input_vector, outputs, nullptr, nullptr, nullptr,
+                     captured_state_->recv_cancellation());
+}
+
+Status RemoteCopyNode::RunRemoteRecv(EagerOperation* op) {
+  EnqueueRequest request;
+  uint64 context_id = ctx_->GetContextId();
+  request.set_context_id(context_id);
+  auto* remote_op = request.add_queue()->mutable_operation();
+  PrepareRemoteOp(remote_op, op);
+  remote_op->set_id(recv_op_id_);
+
+  eager::EagerClient* eager_client;
+  Status status = ctx_->GetClient(recv_device_, &eager_client);
+  if (!status.ok()) {
+    captured_state_->dst()->Poison(status);
+    return status;
+  }
+
+  EnqueueResponse* response = new EnqueueResponse;
+
+  // Don't issue the recv until send has completed.
+  //  - local send will complete very quickly.
+  //  - remote send will take some time, but remote->remote copy is
+  //    probably rare enough that we don't care much.
+  // Blocks until send has completed.
+  Status send_status = captured_state_->GetSendStatus();
+
+  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
+  Device* recv_device = recv_device_;
+  Notification n;
+  eager_client->EnqueueAsync(
+      &request, response,
+      [captured_state, response, recv_device, &n, &status](const Status& s) {
+        status.Update(s);
+        if (status.ok()) {
+          status = captured_state->dst()->SetRemoteShape(
+              response->queue_response(0).shape(0), recv_device);
+        } else {
+          captured_state->dst()->Poison(status);
+        }
+        delete response;
+        n.Notify();
+      });
+  n.WaitForNotification();
+
+  return status;
+}
+
+Status RemoteCopyNode::StartRecv() {
   // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
   // functionality here instead of constructing an Op.
   const AttrTypeMap* types;
   bool is_function = false;
-  TF_RETURN_IF_ERROR(AttrTypeMapForOp("_Recv", &types, &is_function));
+  Status status = AttrTypeMapForOp("_Recv", &types, &is_function);
+  if (!status.ok()) {
+    captured_state_->dst()->Poison(status);
+    return status;
+  }
   DCHECK(!is_function);
   EagerOperation op(ctx_, "_Recv", /*is_function=*/false, types);

@@ -170,93 +251,46 @@ Status RemoteCopyNode::RunRecv() {
   op.MutableAttrs()->Set("tensor_type", src_->dtype);

   if (recv_device_->IsLocal()) {
-    TF_RETURN_IF_ERROR(executor_->status());
-
-    core::RefCountPtr<KernelAndDevice> kernel;
-    TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(&op, &kernel));
-
-    std::vector<Tensor> outputs;
-    gtl::InlinedVector<TensorValue, 4> input_vector;
-    TF_RETURN_IF_ERROR(kernel->Run(input_vector, &outputs, nullptr, nullptr,
-                                   nullptr, &recv_cancellation_));
-    return dst_->SetTensor(outputs[0]);
-  } else {
-    eager::EagerClient* eager_client;
-    uint64 context_id = ctx_->GetContextId();
-    TF_RETURN_IF_ERROR(ctx_->GetClient(recv_device_, &eager_client));
-
-    std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
-
-    request->set_context_id(context_id);
-
-    auto* remote_op = request->add_queue()->mutable_operation();
-    PrepareRemoteOp(remote_op, &op);
-    remote_op->set_id(recv_op_id_);
-
-    EnqueueResponse response;
-    Status status;
-    Notification n;
-
-    CancellationToken token = recv_cancellation_.get_cancellation_token();
-    bool already_cancelled =
-        !recv_cancellation_.RegisterCallback(token, [&n, &status] {
-          status.Update(errors::Cancelled(
-              "Recv op is cancelled due to an error in Send op."));
-          n.Notify();
-        });
-
-    if (already_cancelled) {
-      status =
-          errors::Cancelled("Recv op is cancelled due to an error in Send op.");
-    } else {
-      // Note(fishx): When the recv op is cancelled, we doesn't clean up the
-      // state on remote server. So the recv op may ran successfully on the
-      // remote server even though we cancel it on client.
-      eager_client->EnqueueAsync(request.get(), &response,
-                                 [this, &n, &status](const Status& s) {
-                                   if (recv_cancellation_.IsCancelled()) return;
-                                   status.Update(s);
-                                   n.Notify();
-                                 });
-      n.WaitForNotification();
-      recv_cancellation_.DeregisterCallback(token);
+    std::vector<Tensor> outputs(1);
+    status = RunLocalRecv(&op, &outputs);
+    if (!status.ok()) {
+      captured_state_->dst()->Poison(status);
+      return status;
     }
-
-    TF_RETURN_IF_ERROR(status);
-
-    return dst_->SetRemoteShape(response.queue_response(0).shape(0),
-                                recv_device_);
+    return captured_state_->dst()->SetTensor(outputs[0]);
+  } else {
+    // Handles captured_state_->dst_ internally.
+    return RunRemoteRecv(&op);
   }
 }

 Status RemoteCopyNode::Run() {
-  Status s = RunSend();
+  Status s = StartSend();
   if (!s.ok()) {
     Abort(s);
     return s;
   }

-  s = RunRecv();
-  if (!s.ok() && errors::IsCancelled(s) && !send_status_.ok()) {
-    // In this case, Recv is cancel because Send op failed. Return the status of
-    // send op instead.
-    Abort(send_status_);
-    return send_status_;
-  }
-  if (!s.ok()) {
-    Abort(s);
+  // StartRecv() takes care of doing the right thing to dst handle.
+  // No need to poison it after this point.
+  s = StartRecv();
+  if (!s.ok() && errors::IsCancelled(s)) {
+    Status send_status = captured_state_->GetSendStatus();
+    if (!send_status.ok()) {
+      // In this case, Recv is cancelled because the Send op failed. Return the
+      // status of the Send op instead.
+      s = send_status;
+    }
   }

   src_->Unref();
-  dst_->Unref();
   ctx_->Unref();
   return s;
 }

 void RemoteCopyNode::Abort(Status status) {
-  dst_->Poison(status);
+  captured_state_->dst()->Poison(status);
   src_->Unref();
-  dst_->Unref();
   ctx_->Unref();
 }
@@ -18,6 +18,7 @@ limitations under the License.

 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -26,11 +27,35 @@ limitations under the License.
 namespace tensorflow {
 namespace eager {

-// This node supports copy a tensor:
-//     Remote -> Remote
-//     Local -> Remote
-//     Remote -> Local
-// To copy a tensor with a host, please use copy_to_device_node instead.
+// This node supports copying a tensor in the following way:
+// - Remote -> Local:
+//   We don't block on the remote _Send op and start executing the local
+//   _Recv immediately after issuing the remote _Send. The local _Recv
+//   kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run)
+//   blocks until the tensor is received. If the remote _Send (or some op
+//   before it) fails, the local callback we give to EnqueueAsync will run
+//   and call CancellationManager.StartCancel(). The blocked local _Recv will
+//   get this notification and return with a cancelled error.
+//
+// - Local -> Remote:
+//   The local _Send op is synchronous and non-blocking, thus it should complete
+//   quickly. We issue remote _Recv RPC only after local _Send completes
+//   successfully. At this point, the tensor to be sent is in the local
+//   Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor
+//   to appear.
+//
+// - Remote -> Remote:
+//   We could issue both remote ops asynchronously, but if remote _Send (or some
+//   op before it) fails, we don't have a good way of cancelling the remote
+//   _Recv. The remote _Recv will deadlock in this case. The current approach
+//   to deal with this issue is to wait for remote _Send to complete before
+//   issuing remote _Recv RPC. Another option is to close the whole streaming
+//   RPC that contains the deadlocked remote _Recv. This would not unblock the
+//   deadlocked RPC on the remote machine without some extra code. Luckily, the
+//   remote -> remote case seems to be fairly rare at this point. So, the
+//   current partially synchronous approach seems fine.
+//
+// To copy a tensor within a host, please use copy_to_device_node instead.
 class RemoteCopyNode : public EagerNode {
  public:
   RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src,
@@ -43,11 +68,69 @@ class RemoteCopyNode : public EagerNode {
   void Abort(Status status) override;

  private:
-  Status RunSend();
-  Status RunRecv();
+  // Runs the _Send operation locally or remotely.
+  // An error return value indicates that _Send did not run successfully.
+  // An OK return value does NOT necessarily indicate that _Send has completed
+  // successfully. It might still fail after this method returns.
+  // StartSend() makes sure that captured_state_->send_status_ is set to the
+  // final _Send status after captured_state->send_done_.WaitForNotification()
+  // returns.
+  Status StartSend();
+
+  // Synchronously runs local send `op` and returns its status.
+  Status RunLocalSend(EagerOperation* op);
+
+  // Runs the _Recv operation locally or remotely.
+  // An error return value indicates that _Recv did not run successfully. It
+  // does not indicate that _Send op has completed since StartRecv could have
+  // encountered an error before waiting for _Send's completion.
+  // An OK return value does NOT necessarily indicate that _Recv has completed
+  // successfully (it does now, but won't when streaming RPCs are turned on).
+  // StartRecv() makes sure that dst_ tensor handle is handled correctly
+  // (potentially after this method returns); a tensor is set in the local
+  // case, a remote shape is set in the remote case, the dst_ handle is
+  // poisoned in either case if there is an error.
+  Status StartRecv();
+
+  // Synchronously runs local receive `op` and returns its status.
+  // Does not wait for the send to complete before running receive.
+  Status RunLocalRecv(EagerOperation* op, std::vector<Tensor>* outputs);
+
+  // Waits for send to complete, then issues remote receive `op` and
+  // returns its status.
+  Status RunRemoteRecv(EagerOperation* op);
+
+  // State that is captured by Send and/or Recv callbacks (depending on which
+  // one(s) is remote) and outlives this node in the case of remote->remote
+  // copy.
+  class CapturedSharedState {
+   public:
+    explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); }
+    ~CapturedSharedState() { dst_->Unref(); }
+
+    void SetSendStatus(Status status) {
+      send_status_.Update(status);
+      send_done_.Notify();
+    }
+
+    Status GetSendStatus() {
+      send_done_.WaitForNotification();
+      return send_status_;
+    }
+
+    TensorHandle* dst() { return dst_; }
+    CancellationManager* recv_cancellation() { return &recv_cancellation_; }
+
+   private:
+    TensorHandle* const dst_;
+    CancellationManager recv_cancellation_;
+    // send_status_ is safe to read only after send_done_.WaitForNotification()
+    // has returned.
+    Status send_status_;
+    Notification send_done_;
+  };
+
   TensorHandle* const src_;
-  TensorHandle* const dst_;
   EagerContext* const ctx_;
   EagerExecutor* const executor_;
   Device* const send_device_;
@@ -55,8 +138,7 @@ class RemoteCopyNode : public EagerNode {
   const string wire_id_;
   const uint64 recv_op_id_;

-  CancellationManager recv_cancellation_;
-  Status send_status_;
+  std::shared_ptr<CapturedSharedState> captured_state_;
 };

 } // namespace eager
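The send_status_ race mentioned in the commit message is fixed by pairing the status with a notification: the status is written once, by whichever path finishes the send, and read only after the notification fires. A simplified standalone equivalent, using standard C++ primitives in place of tensorflow::Notification and tensorflow::Status (illustrative only, not the actual TensorFlow code):

```cpp
// Simplified stand-in for CapturedSharedState's send-status handling, using
// std::mutex/std::condition_variable instead of tensorflow::Notification.
#include <condition_variable>
#include <mutex>
#include <string>

class SendState {
 public:
  // Called exactly once, from whichever thread finishes the send
  // (the local kernel run or the EnqueueAsync callback).
  void SetSendStatus(std::string status) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      status_ = std::move(status);
      done_ = true;
    }
    cv_.notify_all();
  }

  // Blocks until the send has finished; only then is status_ safe to read,
  // which is what makes the recv path's wait race-free.
  std::string GetSendStatus() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return done_; });
    return status_;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool done_ = false;
  std::string status_;
};
```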