Support running a remote function with packed input handles.
- Support copying a packed TensorHandle from a client to a remote worker.

PiperOrigin-RevId: 311404609
Change-Id: Iadf2c7793dc3631f7be05de611d059733bbfdd63
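For context, a rough sketch of the client-side flow this change enables, mirroring the new C API test in the diff below. It is illustrative only: the handles `h0`/`h1`, the function name, and the device string are stand-ins, and it assumes the experimental `TFE_CreatePackedTensorHandle` C API.

    // Pack per-device handles into one TFE_TensorHandle. h0 and h1 are
    // assumed to be existing handles living on two different worker tasks.
    TF_Status* status = TF_NewStatus();
    TFE_TensorHandle* handles[] = {h0, h1};
    int num_handles = 2;
    TFE_TensorHandle* packed_handle =
        TFE_CreatePackedTensorHandle(ctx, handles, &num_handles, status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    // Feed the packed handle to a registered function and (optionally)
    // place the function on a remote task; "AddVariablesFunction" and the
    // device string are placeholders borrowed from the test.
    TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
    TFE_OpAddInput(func, packed_handle, status);
    TFE_OpSetDevice(func, "/job:localhost/replica:0/task:1/device:CPU:0",
                    status);
    TFE_TensorHandle* retvals[1] = {nullptr};
    int num_retvals = 1;
    TFE_Execute(func, &retvals[0], &num_retvals, status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);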
@@ -434,7 +434,7 @@ string AddVariablesFunction() {
   return def.SerializeAsString();
 }
 
-TEST(CAPI, TestFunctionWithPackedInput) {
+void TestFunctionWithPackedInput(const bool remote) {
   tensorflow::ServerDef server_def = GetServerDef(3);
 
   // This server def has the task index set to 0.
@@ -502,6 +502,10 @@ TEST(CAPI, TestFunctionWithPackedInput) {
   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
   TFE_OpAddInput(func, packed_handle, status);
   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  if (remote) {
+    TFE_OpSetDevice(func, task1_name, status);
+    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  }
 
   TFE_TensorHandle* retvals[1] = {nullptr};
   int num_retvals = 1;
@@ -537,6 +541,14 @@ TEST(CAPI, TestFunctionWithPackedInput) {
   worker_server2.release();
 }
 
+TEST(CAPI, TestLocalFunctionWithPackedInput) {
+  TestFunctionWithPackedInput(/*remote=*/false);
+}
+
+TEST(CAPI, TestRemoteFunctionWithPackedInput) {
+  TestFunctionWithPackedInput(/*remote=*/true);
+}
+
 void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
   tensorflow::ServerDef server_def = GetServerDef(2);
 
@@ -318,17 +318,14 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
 }
 
 Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
+                                        const tensorflow::DataType dtype,
+                                        const tensorflow::TensorShape& shape,
                                         EagerContext* ctx,
                                         TensorHandle** packed_handle) {
   if (handles.empty()) {
     return errors::InvalidArgument("Handles should not be empty.");
   }
 
-  // Get the dtype and shape from the first handle since all handles have the
-  // same dtype and shape.
-  tensorflow::DataType dtype = handles.at(0)->dtype;
-  tensorflow::TensorShape shape;
-  TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
   ResourceHandleInfo resource_handle_info;
   if (dtype == DT_RESOURCE) {
     TF_RETURN_IF_ERROR(
@@ -360,6 +357,22 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
   return Status::OK();
 }
 
+Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
+                                        EagerContext* ctx,
+                                        TensorHandle** packed_handle) {
+  if (handles.empty()) {
+    return errors::InvalidArgument("Handles should not be empty.");
+  }
+
+  // Get the dtype and shape from the first handle since all handles have the
+  // same dtype and shape.
+  tensorflow::DataType dtype = handles.at(0)->dtype;
+  tensorflow::TensorShape shape;
+  TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
+  return CreatePackedHandle(std::move(handles), dtype, shape, ctx,
+                            packed_handle);
+}
+
 TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
                            const tensorflow::DataType dtype,
                            const tensorflow::TensorShape& shape,
@@ -91,6 +91,11 @@ class TensorHandle : public AbstractTensorHandleInterface,
   // Create a handle which packs the given handles of the same dtype and shape.
   // If handles are on different devices, assign the packed handle to a
   // CompositeDevice.
+  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
+                                   const tensorflow::DataType dtype,
+                                   const tensorflow::TensorShape& shape,
+                                   EagerContext* ctx,
+                                   TensorHandle** packed_handle);
   static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                    EagerContext* ctx,
                                    TensorHandle** packed_handle);
@@ -524,6 +524,8 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
       s = context->Context()->Executor().AddOrExecute(std::move(node));
     } else if (item.has_send_tensor()) {
      s = SendTensor(item.send_tensor(), context->Context());
+    } else if (item.has_send_packed_handle()) {
+      s = SendPackedHandle(item.send_packed_handle(), context->Context());
     } else if (item.has_register_function()) {
       s = RegisterFunction(item.register_function(), context->Context());
     } else if (item.has_cleanup_function()) {
@@ -643,6 +645,52 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
   return Status::OK();
 }
 
+Status EagerServiceImpl::SendPackedHandle(
+    const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) {
+  if (send_packed_handle.handles().empty()) {
+    return errors::InvalidArgument("Handles should not be empty.");
+  }
+
+  std::vector<tensorflow::TensorHandle*> handles;
+  handles.resize(send_packed_handle.handles_size());
+  for (int i = 0; i < send_packed_handle.handles_size(); ++i) {
+    const auto& item = send_packed_handle.handles(i);
+    if (item.has_local_handle()) {
+      Tensor tensor;
+      if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) {
+        return errors::InvalidArgument(
+            "Invalid TensorProto: ",
+            item.local_handle().tensor().DebugString());
+      }
+      Device* op_device = nullptr;
+      TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
+          item.local_handle().device().c_str(), &op_device));
+      handles[i] = TensorHandle::CreateLocalHandle(
+          std::move(tensor), /*d=*/nullptr, op_device, eager_context);
+    } else {
+      TF_RETURN_IF_ERROR(
+          eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
+              item.remote_handle(), &handles[i]));
+    }
+  }
+
+  tensorflow::TensorHandle* packed_handle = nullptr;
+  std::vector<tensorflow::TensorHandle*> handles_to_pack = handles;
+  // Create an unshaped packed TensorHandle.
+  TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
+      std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
+      eager_context, &packed_handle));
+
+  for (auto* h : handles) {
+    // Unref handle since it has a ref in the packed handle now.
+    h->Unref();
+  }
+
+  eager_context->RemoteMgr()->AddOperationOutputs({packed_handle},
+                                                  send_packed_handle.op_id());
+  return Status::OK();
+}
+
 tensorflow::Status EagerServiceImpl::GetServerContext(
     uint64 context_id, ServerContext** server_context) {
   tf_shared_lock l(contexts_mu_);
@@ -212,6 +212,8 @@ class EagerServiceImpl {
                        QueueResponse* queue_response);
   Status SendTensor(const SendTensorOp& send_tensor,
                     EagerContext* eager_context);
+  Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle,
+                          EagerContext* eager_context);
   Status RegisterFunction(const RegisterFunctionOp& register_function,
                           EagerContext* eager_context);
   Status CleanupFunction(const CleanupFunctionOp& cleanup_function);
@@ -881,6 +881,108 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
                                               &close_context_response));
 }
 
+// Tests serializing and sending a packed TensorHandle.
+TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
+  TestEagerServiceImpl eager_service_impl(&worker_env_);
+
+  const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
+  const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
+  const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
+
+  uint64 context_id = random::New64();
+  CreateContextRequest request;
+  auto* server_def = request.mutable_server_def();
+  server_def->set_job_name("localhost");
+  server_def->set_task_index(0);
+  request.add_cluster_device_attributes()->set_name(device0);
+  request.add_cluster_device_attributes()->set_name(device1);
+  request.add_cluster_device_attributes()->set_name(device2);
+  request.set_context_id(context_id);
+  CreateContextResponse response;
+
+  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
+
+  EnqueueRequest remote_enqueue_request;
+  remote_enqueue_request.set_context_id(context_id);
+  EnqueueResponse remote_enqueue_response;
+
+  // Copy a tensor to device0.
+  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
+  send_tensor->set_op_id(1);
+  SetTensorProto(send_tensor->add_tensors());
+
+  // Copy a packed handle to device0.
+  auto* send_packed_handle =
+      remote_enqueue_request.add_queue()->mutable_send_packed_handle();
+  send_packed_handle->set_op_id(3);
+  RemoteTensorHandle* remote_handle =
+      send_packed_handle->add_handles()->mutable_remote_handle();
+  remote_handle->set_op_id(send_tensor->op_id());
+  remote_handle->set_output_num(0);
+  remote_handle->set_op_device(device0);
+  remote_handle->set_device(device0);
+
+  SendPackedHandleOp::LocalTensorHandle* local_handle =
+      send_packed_handle->add_handles()->mutable_local_handle();
+  SetTensorProto(local_handle->mutable_tensor());
+  local_handle->set_device(device1);
+
+  remote_handle = send_packed_handle->add_handles()->mutable_remote_handle();
+  remote_handle->set_op_id(2);
+  remote_handle->set_output_num(5);
+  remote_handle->set_op_device(device2);
+  remote_handle->set_device(device2);
+
+  TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
+                                          &remote_enqueue_response));
+
+  tensorflow::TensorHandle* packed_handle;
+  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
+      context_id, RemoteTensorHandleInternal(3, 0), &packed_handle));
+
+  EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
+  EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
+
+  TensorHandle* handle0 = nullptr;
+  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
+  EXPECT_EQ(handle0->Type(), TensorHandle::LOCAL);
+  EXPECT_EQ(handle0->op_device()->name(), device0);
+  const Tensor* t0 = nullptr;
+  TF_ASSERT_OK(handle0->Tensor(&t0));
+  auto actual = t0->flat<float>();
+  EXPECT_EQ(4, actual.size());
+  EXPECT_EQ(1.0, actual(0));
+  EXPECT_EQ(2.0, actual(1));
+  EXPECT_EQ(3.0, actual(2));
+  EXPECT_EQ(4.0, actual(3));
+
+  TensorHandle* handle1 = nullptr;
+  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(1, &handle1));
+  EXPECT_EQ(handle1->Type(), TensorHandle::LOCAL);
+  EXPECT_EQ(handle1->op_device()->name(), device1);
+  const Tensor* t1 = nullptr;
+  TF_ASSERT_OK(handle0->Tensor(&t1));
+  EXPECT_EQ(t1, t0);
+
+  TensorHandle* handle2 = nullptr;
+  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(2, &handle2));
+  EXPECT_EQ(handle2->Type(), TensorHandle::REMOTE);
+  EXPECT_EQ(handle2->op_device()->name(), device2);
+  int64 op_id;
+  int32 output_num;
+  TF_ASSERT_OK(handle2->RemoteAddressUntilReady(
+      absl::get<Device*>(handle2->device()), &op_id, &output_num));
+  EXPECT_EQ(op_id, 2);
+  EXPECT_EQ(output_num, 5);
+
+  CloseContextRequest close_context_request;
+  close_context_request.set_context_id(context_id);
+  close_context_request.set_context_view_id(0);
+  CloseContextResponse close_context_response;
+  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
+                                               &close_context_response));
+}
+
 // Test requests sent to the eager service on master.
 TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
   tensorflow::Rendezvous* rendezvous =
@@ -25,6 +25,8 @@ limitations under the License.
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/protobuf.h"
 
 namespace tensorflow {
 namespace eager {
@@ -290,6 +292,102 @@ void RemoteCopyNode::StartRecv(StatusCallback done) {
   }
 }
 
+Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
+                             const Device* target_device, EagerContext* ctx,
+                             SendPackedHandleOp* op) {
+  op->set_op_id(op_id);
+  for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
+    TensorHandle* h = nullptr;
+    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
+    if (h->Type() == TensorHandle::LOCAL) {
+      // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
+      // copy it to the CPU before copying it out.
+      Tensor tensor;
+      TF_RETURN_IF_ERROR(h->CopyToDevice(*ctx, ctx->HostCPU(), &tensor));
+      auto* local_handle = op->add_handles()->mutable_local_handle();
+      local_handle->set_device(h->op_device() ? h->op_device()->name()
+                                              : ctx->HostCPU()->name());
+      tensor.AsProtoTensorContent(local_handle->mutable_tensor());
+    } else if (h->Type() == TensorHandle::REMOTE) {
+      // Only serialize the resource dtype and shape of the first handle, since
+      // all handles are of the same resource dtype and shape.
+      Device* src_device = absl::get<Device*>(h->device());
+      const bool serialize_resource_dtype_and_shape =
+          (i == 0) && (h->dtype == DT_RESOURCE) &&
+          (ctx->OnSameTask(src_device, target_device));
+      TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
+          h, op->add_handles()->mutable_remote_handle(), src_device,
+          absl::get<Device*>(h->DeviceOrHostCPU(*ctx))->name(),
+          serialize_resource_dtype_and_shape));
+    } else {
+      return errors::InvalidArgument("Nested packed handles are not supported");
+    }
+  }
+  return Status::OK();
+}
+
+void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) {
+  Status s;
+  const uint64 context_view_id = ctx_->GetContextViewId();
+  if (!send_device_->IsLocal()) {
+    s = errors::InvalidArgument(
+        "Copying a packed handle from a remote device is not supported");
+    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
+    done(s);
+    return;
+  }
+
+  EnqueueRequest request;
+  uint64 context_id = ctx_->GetContextId();
+  request.set_context_id(context_id);
+  s = SerializePackedHandle(recv_op_id_, src_, recv_device_, ctx_,
+                            request.add_queue()->mutable_send_packed_handle());
+  if (!s.ok()) {
+    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
+    done(s);
+    return;
+  }
+
+  TensorShape shape;
+  s = src_->Shape(&shape);
+  if (!s.ok()) {
+    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
+    done(s);
+    return;
+  }
+  captured_state_->SetSrcShape(shape);
+
+  core::RefCountPtr<eager::EagerClient> eager_client;
+  s = ctx_->GetClient(recv_device_, &eager_client);
+  if (!s.ok()) {
+    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
+    done(s);
+    return;
+  }
+
+  EnqueueResponse* response = new EnqueueResponse;
+  Device* recv_device = recv_device_;
+  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
+  eager_client->StreamingEnqueueAsync(
+      &request, response,
+      [captured_state, response, recv_device, context_view_id,
+       done](const Status& s) {
+        if (s.ok()) {
+          Status status = captured_state->dst()->SetRemoteShape(
+              captured_state->GetSrcShape(), recv_device, context_view_id);
+          if (!status.ok()) {
+            LOG(ERROR) << "Ignoring an error encountered when setting remote "
+                          "shape of tensor received by SendPackedHandle rpc: "
+                       << status.ToString();
+          }
+        } else {
+          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
+        }
+        done(s);
+        delete response;
+      });
+}
+
 void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
   Status s;
   EnqueueRequest request;
@@ -351,7 +449,11 @@ Status RemoteCopyNode::Prepare() {
 
 void RemoteCopyNode::RunAsync(StatusCallback done) {
   started_ = true;
-  if (ctx_->UseSendTensorRPC() && send_device_->IsLocal() &&
+  if (src_->Type() == TensorHandle::PACKED) {
+    return StartSendPackedHandle(std::move(done));
+  }
+
+  if ((ctx_->UseSendTensorRPC()) && send_device_->IsLocal() &&
       !recv_device_->IsLocal()) {
     return StartRemoteSendTensor(std::move(done));
   }
@@ -121,6 +121,9 @@ class RemoteCopyNode : public AsyncEagerNode {
   // SendTensor RPC *on the receiver*.
   void StartRemoteSendTensor(StatusCallback done);
 
+  // Send a local packed TensorHandle to a remote device.
+  void StartSendPackedHandle(StatusCallback done);
+
   // State that is captured by Send and/or Recv callbacks (depending on which
   // one(s) is remote) and outlives this node in the case of remote->remote
   // copy.
@@ -69,6 +69,7 @@ message QueueItem {
     // enqueued in streaming call. Request with this item type waits for pending
     // nodes to finish on the remote executor and report status.
     SyncRemoteExecutorForStream sync_remote_executor_for_stream = 6;
+    SendPackedHandleOp send_packed_handle = 7;
   }
 }
 
@@ -238,6 +239,27 @@ message SendTensorOp {
   string device_name = 3;
 }
 
+// Send a packed TensorHandle to a remote worker.
+message SendPackedHandleOp {
+  // Op id of the remote packed TensorHandle.
+  int64 op_id = 1;
+
+  message LocalTensorHandle {
+    TensorProto tensor = 1;
+    // Device where the tensor is produced.
+    string device = 2;
+  }
+
+  message Handle {
+    oneof item {
+      LocalTensorHandle local_handle = 1;
+      RemoteTensorHandle remote_handle = 2;
+    }
+  }
+
+  repeated Handle handles = 2;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 //
 // Eager Service defines a TensorFlow service that executes operations eagerly
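As a non-authoritative illustration of this wire format, the sketch below assembles a SendPackedHandleOp inside an EnqueueRequest the way SendPackedHandleTest above does. The op ids, device strings, `context_id`, and `my_tensor` are hypothetical placeholders.

    // Sketch only: one queue item packs a handle that already lives on the
    // remote worker together with a client-local tensor shipped inline.
    eager::EnqueueRequest request;
    request.set_context_id(context_id);  // an existing eager context id (assumed)
    auto* op = request.add_queue()->mutable_send_packed_handle();
    op->set_op_id(3);  // id under which the packed handle is stored remotely

    // Component 0: a handle produced earlier on the worker (op_id 1, output 0).
    auto* remote = op->add_handles()->mutable_remote_handle();
    remote->set_op_id(1);
    remote->set_output_num(0);
    remote->set_device("/job:localhost/replica:0/task:0/device:CPU:0");
    remote->set_op_device("/job:localhost/replica:0/task:0/device:CPU:0");

    // Component 1: a client-local tensor serialized as a TensorProto.
    // my_tensor is a hypothetical in-memory tensorflow::Tensor.
    auto* local = op->add_handles()->mutable_local_handle();
    local->set_device("/job:localhost/replica:0/task:1/device:CPU:0");
    my_tensor.AsProtoTensorContent(local->mutable_tensor());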