Support running a remote function with packed input handles.

- Support copying a packed TensorHandle from a client to a remote worker.

PiperOrigin-RevId: 311404609
Change-Id: Iadf2c7793dc3631f7be05de611d059733bbfdd63
Yujing Zhang 2020-05-13 14:27:47 -07:00 committed by TensorFlower Gardener
parent 4f6a3a4db0
commit 8588e0aab8
9 changed files with 316 additions and 7 deletions
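
For orientation, here is a condensed sketch of the client-side flow the updated C API test below exercises: pack per-replica handles into one handle, feed it to a function, and (new in this change) place that function on a remote task. The function name, device string, and the TFE_CreatePackedTensorHandle helper are assumptions drawn from the surrounding test file rather than part of this diff; server setup, error handling, and cleanup are omitted.

TF_Status* status = TF_NewStatus();
// Pack two per-replica handles (h0, h1 assumed to share dtype and shape)
// into a single packed handle. TFE_CreatePackedTensorHandle is assumed to be
// available from the experimental eager C API used by this test file.
TFE_TensorHandle* handles[] = {h0, h1};
int num_handles = 2;
TFE_TensorHandle* packed_handle =
    TFE_CreatePackedTensorHandle(ctx, handles, &num_handles, status);
// Feed the packed handle to a registered function; for the remote case added
// by this change, pin the function to a remote task before executing it.
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
TFE_OpAddInput(func, packed_handle, status);
TFE_OpSetDevice(func, "/job:localhost/replica:0/task:1/device:CPU:0", status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);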


@@ -434,7 +434,7 @@ string AddVariablesFunction() {
return def.SerializeAsString();
}
TEST(CAPI, TestFunctionWithPackedInput) {
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@@ -502,6 +502,10 @@ TEST(CAPI, TestFunctionWithPackedInput) {
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
@@ -537,6 +541,14 @@ TEST(CAPI, TestFunctionWithPackedInput) {
worker_server2.release();
}
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);


@@ -318,17 +318,14 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
}
Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
const tensorflow::DataType dtype,
const tensorflow::TensorShape& shape,
EagerContext* ctx,
TensorHandle** packed_handle) {
if (handles.empty()) {
return errors::InvalidArgument("Handles should not be empty.");
}
// Get the dtype and shape from the first handle since all handles have the
// same dtype and shape.
tensorflow::DataType dtype = handles.at(0)->dtype;
tensorflow::TensorShape shape;
TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
ResourceHandleInfo resource_handle_info;
if (dtype == DT_RESOURCE) {
TF_RETURN_IF_ERROR(
@@ -360,6 +357,22 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
return Status::OK();
}
Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
EagerContext* ctx,
TensorHandle** packed_handle) {
if (handles.empty()) {
return errors::InvalidArgument("Handles should not be empty.");
}
// Get the dtype and shape from the first handle since all handles have the
// same dtype and shape.
tensorflow::DataType dtype = handles.at(0)->dtype;
tensorflow::TensorShape shape;
TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
return CreatePackedHandle(std::move(handles), dtype, shape, ctx,
packed_handle);
}
TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
const tensorflow::DataType dtype,
const tensorflow::TensorShape& shape,


@@ -91,6 +91,11 @@ class TensorHandle : public AbstractTensorHandleInterface,
// Create a handle which packs the given handles of the same dtype and shape.
// If handles are on different devices, assign the packed handle to a
// CompositeDevice.
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
const tensorflow::DataType dtype,
const tensorflow::TensorShape& shape,
EagerContext* ctx,
TensorHandle** packed_handle);
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
EagerContext* ctx,
TensorHandle** packed_handle);
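
As an aside, a minimal usage sketch of the new convenience overload (not part of the diff; assumes `ctx`, `h0`, and `h1` already exist and that the two handles have the same dtype and shape):

tensorflow::Status PackTwoHandles(tensorflow::EagerContext* ctx,
                                  tensorflow::TensorHandle* h0,
                                  tensorflow::TensorHandle* h1,
                                  tensorflow::TensorHandle** packed) {
  // The overload without explicit dtype/shape reads them from the first
  // handle; if h0 and h1 live on different devices, the packed handle is
  // assigned to a CompositeDevice, as the comment above describes.
  std::vector<tensorflow::TensorHandle*> handles = {h0, h1};
  return tensorflow::TensorHandle::CreatePackedHandle(std::move(handles), ctx,
                                                      packed);
}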


@@ -524,6 +524,8 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
s = context->Context()->Executor().AddOrExecute(std::move(node));
} else if (item.has_send_tensor()) {
s = SendTensor(item.send_tensor(), context->Context());
} else if (item.has_send_packed_handle()) {
s = SendPackedHandle(item.send_packed_handle(), context->Context());
} else if (item.has_register_function()) {
s = RegisterFunction(item.register_function(), context->Context());
} else if (item.has_cleanup_function()) {
@@ -643,6 +645,52 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
return Status::OK();
}
Status EagerServiceImpl::SendPackedHandle(
const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) {
if (send_packed_handle.handles().empty()) {
return errors::InvalidArgument("Handles should not be empty.");
}
std::vector<tensorflow::TensorHandle*> handles;
handles.resize(send_packed_handle.handles_size());
for (int i = 0; i < send_packed_handle.handles_size(); ++i) {
const auto& item = send_packed_handle.handles(i);
if (item.has_local_handle()) {
Tensor tensor;
if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) {
return errors::InvalidArgument(
"Invalid TensorProto: ",
item.local_handle().tensor().DebugString());
}
Device* op_device = nullptr;
TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
item.local_handle().device().c_str(), &op_device));
handles[i] = TensorHandle::CreateLocalHandle(
std::move(tensor), /*d=*/nullptr, op_device, eager_context);
} else {
TF_RETURN_IF_ERROR(
eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
item.remote_handle(), &handles[i]));
}
}
tensorflow::TensorHandle* packed_handle = nullptr;
std::vector<tensorflow::TensorHandle*> handles_to_pack = handles;
// Create an unshaped packed TensorHandle.
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
eager_context, &packed_handle));
for (auto* h : handles) {
// Unref the handle since the packed handle now holds a reference to it.
h->Unref();
}
eager_context->RemoteMgr()->AddOperationOutputs({packed_handle},
send_packed_handle.op_id());
return Status::OK();
}
tensorflow::Status EagerServiceImpl::GetServerContext(
uint64 context_id, ServerContext** server_context) {
tf_shared_lock l(contexts_mu_);


@@ -212,6 +212,8 @@ class EagerServiceImpl {
QueueResponse* queue_response);
Status SendTensor(const SendTensorOp& send_tensor,
EagerContext* eager_context);
Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle,
EagerContext* eager_context);
Status RegisterFunction(const RegisterFunctionOp& register_function,
EagerContext* eager_context);
Status CleanupFunction(const CleanupFunctionOp& cleanup_function);


@@ -881,6 +881,108 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
&close_context_response));
}
// Test serializing and sending a packed TensorHandle.
TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
TestEagerServiceImpl eager_service_impl(&worker_env_);
const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
uint64 context_id = random::New64();
CreateContextRequest request;
auto* server_def = request.mutable_server_def();
server_def->set_job_name("localhost");
server_def->set_task_index(0);
request.add_cluster_device_attributes()->set_name(device0);
request.add_cluster_device_attributes()->set_name(device1);
request.add_cluster_device_attributes()->set_name(device2);
request.set_context_id(context_id);
CreateContextResponse response;
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
EnqueueRequest remote_enqueue_request;
remote_enqueue_request.set_context_id(context_id);
EnqueueResponse remote_enqueue_response;
// Copy a tensor to device0
auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
send_tensor->set_op_id(1);
SetTensorProto(send_tensor->add_tensors());
// Copy a packed handle to device0
auto* send_packed_handle =
remote_enqueue_request.add_queue()->mutable_send_packed_handle();
send_packed_handle->set_op_id(3);
RemoteTensorHandle* remote_handle =
send_packed_handle->add_handles()->mutable_remote_handle();
remote_handle->set_op_id(send_tensor->op_id());
remote_handle->set_output_num(0);
remote_handle->set_op_device(device0);
remote_handle->set_device(device0);
SendPackedHandleOp::LocalTensorHandle* local_handle =
send_packed_handle->add_handles()->mutable_local_handle();
SetTensorProto(local_handle->mutable_tensor());
local_handle->set_device(device1);
remote_handle = send_packed_handle->add_handles()->mutable_remote_handle();
remote_handle->set_op_id(2);
remote_handle->set_output_num(5);
remote_handle->set_op_device(device2);
remote_handle->set_device(device2);
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
&remote_enqueue_response));
tensorflow::TensorHandle* packed_handle;
TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
context_id, RemoteTensorHandleInternal(3, 0), &packed_handle));
EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
TensorHandle* handle0 = nullptr;
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
EXPECT_EQ(handle0->Type(), TensorHandle::LOCAL);
EXPECT_EQ(handle0->op_device()->name(), device0);
const Tensor* t0 = nullptr;
TF_ASSERT_OK(handle0->Tensor(&t0));
auto actual = t0->flat<float>();
EXPECT_EQ(4, actual.size());
EXPECT_EQ(1.0, actual(0));
EXPECT_EQ(2.0, actual(1));
EXPECT_EQ(3.0, actual(2));
EXPECT_EQ(4.0, actual(3));
TensorHandle* handle1 = nullptr;
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(1, &handle1));
EXPECT_EQ(handle1->Type(), TensorHandle::LOCAL);
EXPECT_EQ(handle1->op_device()->name(), device1);
const Tensor* t1 = nullptr;
TF_ASSERT_OK(handle0->Tensor(&t1));
EXPECT_EQ(t1, t0);
TensorHandle* handle2 = nullptr;
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(2, &handle2));
EXPECT_EQ(handle2->Type(), TensorHandle::REMOTE);
EXPECT_EQ(handle2->op_device()->name(), device2);
int64 op_id;
int32 output_num;
TF_ASSERT_OK(handle2->RemoteAddressUntilReady(
absl::get<Device*>(handle2->device()), &op_id, &output_num));
EXPECT_EQ(op_id, 2);
EXPECT_EQ(output_num, 5);
CloseContextRequest close_context_request;
close_context_request.set_context_id(context_id);
close_context_request.set_context_view_id(0);
CloseContextResponse close_context_response;
TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
&close_context_response));
}
// Test requests sent to the eager service on master.
TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
tensorflow::Rendezvous* rendezvous =


@@ -25,6 +25,8 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
namespace eager {
@@ -290,6 +292,102 @@ void RemoteCopyNode::StartRecv(StatusCallback done) {
}
}
Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
const Device* target_device, EagerContext* ctx,
SendPackedHandleOp* op) {
op->set_op_id(op_id);
for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
TensorHandle* h = nullptr;
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
if (h->Type() == TensorHandle::LOCAL) {
// AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
// copy it to the CPU before copying it out.
Tensor tensor;
TF_RETURN_IF_ERROR(h->CopyToDevice(*ctx, ctx->HostCPU(), &tensor));
auto* local_handle = op->add_handles()->mutable_local_handle();
local_handle->set_device(h->op_device() ? h->op_device()->name()
: ctx->HostCPU()->name());
tensor.AsProtoTensorContent(local_handle->mutable_tensor());
} else if (h->Type() == TensorHandle::REMOTE) {
// Only serialize the resource dtype and shape of the first handle, since
// all handles are of the same resource dtype and shape.
Device* src_device = absl::get<Device*>(h->device());
const bool serialize_resource_dtype_and_shape =
(i == 0) && (h->dtype == DT_RESOURCE) &&
(ctx->OnSameTask(src_device, target_device));
TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
h, op->add_handles()->mutable_remote_handle(), src_device,
absl::get<Device*>(h->DeviceOrHostCPU(*ctx))->name(),
serialize_resource_dtype_and_shape));
} else {
return errors::InvalidArgument("Nested packed handles are not supported");
}
}
return Status::OK();
}
void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) {
Status s;
const uint64 context_view_id = ctx_->GetContextViewId();
if (!send_device_->IsLocal()) {
s = errors::InvalidArgument(
"Copy a packed handle from a remote device is not supported");
captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
done(s);
return;
}
EnqueueRequest request;
uint64 context_id = ctx_->GetContextId();
request.set_context_id(context_id);
s = SerializePackedHandle(recv_op_id_, src_, recv_device_, ctx_,
request.add_queue()->mutable_send_packed_handle());
if (!s.ok()) {
captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
done(s);
return;
}
TensorShape shape;
s = src_->Shape(&shape);
if (!s.ok()) {
captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
done(s);
return;
}
captured_state_->SetSrcShape(shape);
core::RefCountPtr<eager::EagerClient> eager_client;
s = ctx_->GetClient(recv_device_, &eager_client);
if (!s.ok()) {
captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
done(s);
return;
}
EnqueueResponse* response = new EnqueueResponse;
Device* recv_device = recv_device_;
const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
eager_client->StreamingEnqueueAsync(
&request, response,
[captured_state, response, recv_device, context_view_id,
done](const Status& s) {
if (s.ok()) {
Status status = captured_state->dst()->SetRemoteShape(
captured_state->GetSrcShape(), recv_device, context_view_id);
if (!status.ok()) {
LOG(ERROR) << "Ignoring an error encountered when setting remote "
"shape of tensor received by SendPackedHadnle rpc: "
<< status.ToString();
}
} else {
captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
}
done(s);
delete response;
});
}
void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
Status s;
EnqueueRequest request;
@@ -351,7 +449,11 @@ Status RemoteCopyNode::Prepare() {
void RemoteCopyNode::RunAsync(StatusCallback done) {
started_ = true;
if (ctx_->UseSendTensorRPC() && send_device_->IsLocal() &&
if (src_->Type() == TensorHandle::PACKED) {
return StartSendPackedHandle(std::move(done));
}
if ((ctx_->UseSendTensorRPC()) && send_device_->IsLocal() &&
!recv_device_->IsLocal()) {
return StartRemoteSendTensor(std::move(done));
}


@@ -121,6 +121,9 @@ class RemoteCopyNode : public AsyncEagerNode {
// SendTensor RPC *on the receiver*.
void StartRemoteSendTensor(StatusCallback done);
// Send a local packed TensorHandle to a remote device.
void StartSendPackedHandle(StatusCallback done);
// State that is captured by Send and/or Recv callbacks (depending on which
// one(s) is remote) and outlives this node in the case of remote->remote
// copy.


@@ -69,6 +69,7 @@ message QueueItem {
// enqueued in streaming call. Request with this item type waits for pending
// nodes to finish on the remote executor and report status.
SyncRemoteExecutorForStream sync_remote_executor_for_stream = 6;
SendPackedHandleOp send_packed_handle = 7;
}
}
@@ -238,6 +239,27 @@ message SendTensorOp {
string device_name = 3;
}
// Send a packed TensorHandle to a remote worker.
message SendPackedHandleOp {
// Op id of the remote packed TensorHandle.
int64 op_id = 1;
message LocalTensorHandle {
TensorProto tensor = 1;
// Device where the tensor is produced.
string device = 2;
}
message Handle {
oneof item {
LocalTensorHandle local_handle = 1;
RemoteTensorHandle remote_handle = 2;
}
}
repeated Handle handles = 2;
}
////////////////////////////////////////////////////////////////////////////////
//
// Eager Service defines a TensorFlow service that executes operations eagerly