diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc
index 12c63675c87..9dc18c7a6f1 100644
--- a/tensorflow/c/eager/c_api_remote_test.cc
+++ b/tensorflow/c/eager/c_api_remote_test.cc
@@ -434,7 +434,7 @@ string AddVariablesFunction() {
   return def.SerializeAsString();
 }
 
-TEST(CAPI, TestFunctionWithPackedInput) {
+void TestFunctionWithPackedInput(const bool remote) {
   tensorflow::ServerDef server_def = GetServerDef(3);
 
   // This server def has the task index set to 0.
@@ -502,6 +502,10 @@
   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
   TFE_OpAddInput(func, packed_handle, status);
   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  if (remote) {
+    TFE_OpSetDevice(func, task1_name, status);
+    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  }
 
   TFE_TensorHandle* retvals[1] = {nullptr};
   int num_retvals = 1;
@@ -537,6 +541,14 @@
   worker_server2.release();
 }
 
+TEST(CAPI, TestLocalFunctionWithPackedInput) {
+  TestFunctionWithPackedInput(/*remote=*/false);
+}
+
+TEST(CAPI, TestRemoteFunctionWithPackedInput) {
+  TestFunctionWithPackedInput(/*remote=*/true);
+}
+
 void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
   tensorflow::ServerDef server_def = GetServerDef(2);
 
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index dfe3e4a1426..49fa69e2185 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -318,17 +318,14 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
 }
 
 Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
+                                        const tensorflow::DataType dtype,
+                                        const tensorflow::TensorShape& shape,
                                         EagerContext* ctx,
                                         TensorHandle** packed_handle) {
   if (handles.empty()) {
     return errors::InvalidArgument("Handles should not be empty.");
   }
 
-  // Get the dtype and shape from the fisrt handle since all handles have the
-  // same dtype and shape.
-  tensorflow::DataType dtype = handles.at(0)->dtype;
-  tensorflow::TensorShape shape;
-  TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
   ResourceHandleInfo resource_handle_info;
   if (dtype == DT_RESOURCE) {
     TF_RETURN_IF_ERROR(
@@ -360,6 +357,22 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
   return Status::OK();
 }
 
+Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
+                                        EagerContext* ctx,
+                                        TensorHandle** packed_handle) {
+  if (handles.empty()) {
+    return errors::InvalidArgument("Handles should not be empty.");
+  }
+
+  // Get the dtype and shape from the first handle since all handles have the
+  // same dtype and shape.
+  tensorflow::DataType dtype = handles.at(0)->dtype;
+  tensorflow::TensorShape shape;
+  TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
+  return CreatePackedHandle(std::move(handles), dtype, shape, ctx,
+                            packed_handle);
+}
+
 TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
                            const tensorflow::DataType dtype,
                            const tensorflow::TensorShape& shape,
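The new CreatePackedHandle overload above lets the caller supply the dtype and shape explicitly instead of having them queried from handles.at(0), so a packed handle can be built before the component shapes are known; EagerServiceImpl::SendPackedHandle later in this patch relies on that to create an unshaped packed handle. Below is a minimal sketch of that call pattern, assuming an existing EagerContext; the helper name MakeUnshapedPackedHandle is illustrative and not part of this change.

// Sketch only: mirrors the call made by EagerServiceImpl::SendPackedHandle
// further down in this patch. The helper name is hypothetical.
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

// Packs `components`, all of dtype `dtype`, without requiring their shapes to
// be known yet; an empty TensorShape is passed as a placeholder.
Status MakeUnshapedPackedHandle(std::vector<TensorHandle*>&& components,
                                DataType dtype, EagerContext* ctx,
                                TensorHandle** out) {
  return TensorHandle::CreatePackedHandle(std::move(components), dtype,
                                          TensorShape(), ctx, out);
}

}  // namespace tensorflow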
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 25d7fea3200..6f9ee565c73 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -91,6 +91,11 @@ class TensorHandle : public AbstractTensorHandleInterface,
   // Create a handle which packs the given handles of the same dtype and shape.
   // If handles are on different devices, assign the packed handle to a
   // CompositeDevice.
+  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
+                                   const tensorflow::DataType dtype,
+                                   const tensorflow::TensorShape& shape,
+                                   EagerContext* ctx,
+                                   TensorHandle** packed_handle);
   static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                    EagerContext* ctx,
                                    TensorHandle** packed_handle);
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 95131150d3d..6dc03cbc527 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -524,6 +524,8 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
       s = context->Context()->Executor().AddOrExecute(std::move(node));
     } else if (item.has_send_tensor()) {
       s = SendTensor(item.send_tensor(), context->Context());
+    } else if (item.has_send_packed_handle()) {
+      s = SendPackedHandle(item.send_packed_handle(), context->Context());
     } else if (item.has_register_function()) {
       s = RegisterFunction(item.register_function(), context->Context());
     } else if (item.has_cleanup_function()) {
@@ -643,6 +645,52 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
   return Status::OK();
 }
 
+Status EagerServiceImpl::SendPackedHandle(
+    const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) {
+  if (send_packed_handle.handles().empty()) {
+    return errors::InvalidArgument("Handles should not be empty.");
+  }
+
+  std::vector<tensorflow::TensorHandle*> handles;
+  handles.resize(send_packed_handle.handles_size());
+  for (int i = 0; i < send_packed_handle.handles_size(); ++i) {
+    const auto& item = send_packed_handle.handles(i);
+    if (item.has_local_handle()) {
+      Tensor tensor;
+      if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) {
+        return errors::InvalidArgument(
+            "Invalid TensorProto: ",
+            item.local_handle().tensor().DebugString());
+      }
+      Device* op_device = nullptr;
+      TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
+          item.local_handle().device().c_str(), &op_device));
+      handles[i] = TensorHandle::CreateLocalHandle(
+          std::move(tensor), /*d=*/nullptr, op_device, eager_context);
+    } else {
+      TF_RETURN_IF_ERROR(
+          eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
+              item.remote_handle(), &handles[i]));
+    }
+  }
+
+  tensorflow::TensorHandle* packed_handle = nullptr;
+  std::vector<tensorflow::TensorHandle*> handles_to_pack = handles;
+  // Create an unshaped packed TensorHandle.
+  TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
+      std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
+      eager_context, &packed_handle));
+
+  for (auto* h : handles) {
+    // Unref handle since it has a ref in the packed handle now.
+    h->Unref();
+  }
+
+  eager_context->RemoteMgr()->AddOperationOutputs({packed_handle},
+                                                  send_packed_handle.op_id());
+  return Status::OK();
+}
+
 tensorflow::Status EagerServiceImpl::GetServerContext(
     uint64 context_id, ServerContext** server_context) {
   tf_shared_lock l(contexts_mu_);
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
index 06d4c36b61c..1e4d36ccf9f 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
@@ -212,6 +212,8 @@ class EagerServiceImpl {
                      QueueResponse* queue_response);
   Status SendTensor(const SendTensorOp& send_tensor,
                     EagerContext* eager_context);
+  Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle,
+                          EagerContext* eager_context);
   Status RegisterFunction(const RegisterFunctionOp& register_function,
                           EagerContext* eager_context);
   Status CleanupFunction(const CleanupFunctionOp& cleanup_function);
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index 9930bb86e6b..23bf324b80f 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -881,6 +881,108 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
                                                &close_context_response));
 }
 
+// Test serializing and sending a packed TensorHandle.
+TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
+  TestEagerServiceImpl eager_service_impl(&worker_env_);
+
+  const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
+  const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
+  const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
+
+  uint64 context_id = random::New64();
+  CreateContextRequest request;
+  auto* server_def = request.mutable_server_def();
+  server_def->set_job_name("localhost");
+  server_def->set_task_index(0);
+  request.add_cluster_device_attributes()->set_name(device0);
+  request.add_cluster_device_attributes()->set_name(device1);
+  request.add_cluster_device_attributes()->set_name(device2);
+  request.set_context_id(context_id);
+  CreateContextResponse response;
+
+  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
+
+  EnqueueRequest remote_enqueue_request;
+  remote_enqueue_request.set_context_id(context_id);
+  EnqueueResponse remote_enqueue_response;
+
+  // Copy a tensor to device0
+  auto* send_tensor =
+      remote_enqueue_request.add_queue()->mutable_send_tensor();
+  send_tensor->set_op_id(1);
+  SetTensorProto(send_tensor->add_tensors());
+
+  // Copy a packed handle to device0
+  auto* send_packed_handle =
+      remote_enqueue_request.add_queue()->mutable_send_packed_handle();
+  send_packed_handle->set_op_id(3);
+  RemoteTensorHandle* remote_handle =
+      send_packed_handle->add_handles()->mutable_remote_handle();
+  remote_handle->set_op_id(send_tensor->op_id());
+  remote_handle->set_output_num(0);
+  remote_handle->set_op_device(device0);
+  remote_handle->set_device(device0);
+
+  SendPackedHandleOp::LocalTensorHandle* local_handle =
+      send_packed_handle->add_handles()->mutable_local_handle();
+  SetTensorProto(local_handle->mutable_tensor());
+  local_handle->set_device(device1);
+
+  remote_handle = send_packed_handle->add_handles()->mutable_remote_handle();
+  remote_handle->set_op_id(2);
+  remote_handle->set_output_num(5);
+  remote_handle->set_op_device(device2);
+  remote_handle->set_device(device2);
+
+  TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
+                                          &remote_enqueue_response));
+
+  tensorflow::TensorHandle* packed_handle;
+  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
+      context_id, RemoteTensorHandleInternal(3, 0), &packed_handle));
+
+  EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
+  EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
+
+  TensorHandle* handle0 = nullptr;
+  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
+  EXPECT_EQ(handle0->Type(), TensorHandle::LOCAL);
+  EXPECT_EQ(handle0->op_device()->name(), device0);
+  const Tensor* t0 = nullptr;
+  TF_ASSERT_OK(handle0->Tensor(&t0));
+  auto actual = t0->flat<float>();
+  EXPECT_EQ(4, actual.size());
+  EXPECT_EQ(1.0, actual(0));
+  EXPECT_EQ(2.0, actual(1));
+  EXPECT_EQ(3.0, actual(2));
+  EXPECT_EQ(4.0, actual(3));
+
+  TensorHandle* handle1 = nullptr;
+  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(1, &handle1));
+  EXPECT_EQ(handle1->Type(), TensorHandle::LOCAL);
+  EXPECT_EQ(handle1->op_device()->name(), device1);
+  const Tensor* t1 = nullptr;
+  TF_ASSERT_OK(handle0->Tensor(&t1));
+  EXPECT_EQ(t1, t0);
+
+  TensorHandle* handle2 = nullptr;
+  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(2, &handle2));
+  EXPECT_EQ(handle2->Type(), TensorHandle::REMOTE);
+  EXPECT_EQ(handle2->op_device()->name(), device2);
+  int64 op_id;
+  int32 output_num;
+  TF_ASSERT_OK(handle2->RemoteAddressUntilReady(
+      absl::get<Device*>(handle2->device()), &op_id, &output_num));
+  EXPECT_EQ(op_id, 2);
+  EXPECT_EQ(output_num, 5);
+
+  CloseContextRequest close_context_request;
+  close_context_request.set_context_id(context_id);
+  close_context_request.set_context_view_id(0);
+  CloseContextResponse close_context_response;
+  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
+                                               &close_context_response));
+}
+
 // Test requests sent to the eager service on master.
 TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
   tensorflow::Rendezvous* rendezvous =
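To summarize the wire format exercised by SendPackedHandleTest above, the sketch below condenses the proto population into a standalone helper. The helper name AddSendPackedHandle and its exact signature are illustrative assumptions; only the fields introduced by this change, plus RemoteTensorHandle, are used.

// Illustrative sketch only: condenses the proto population done in
// SendPackedHandleTest. The helper name and signature are hypothetical.
#include <string>

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"

namespace tensorflow {
namespace eager {

// Queues a SendPackedHandleOp that packs one inline tensor (living on
// `local_device`) with one tensor already known to the receiver under
// (`remote_op_id`, `remote_output_num`) on `remote_device`. The packed handle
// becomes addressable on the receiver as (`op_id`, output 0).
void AddSendPackedHandle(uint64 op_id, const TensorProto& tensor,
                         const std::string& local_device, int64 remote_op_id,
                         int32 remote_output_num,
                         const std::string& remote_device,
                         EnqueueRequest* request) {
  SendPackedHandleOp* op = request->add_queue()->mutable_send_packed_handle();
  op->set_op_id(op_id);

  // Component 0: a local tensor shipped inline as a TensorProto.
  auto* local = op->add_handles()->mutable_local_handle();
  *local->mutable_tensor() = tensor;
  local->set_device(local_device);

  // Component 1: a handle that already exists on the receiving worker.
  RemoteTensorHandle* remote = op->add_handles()->mutable_remote_handle();
  remote->set_op_id(remote_op_id);
  remote->set_output_num(remote_output_num);
  remote->set_op_device(remote_device);
  remote->set_device(remote_device);
}

}  // namespace eager
}  // namespace tensorflow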
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
index b281bcef2b3..5d0793b258c 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
@@ -25,6 +25,8 @@ limitations under the License.
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/protobuf.h"
 
 namespace tensorflow {
 namespace eager {
@@ -290,6 +292,102 @@ void RemoteCopyNode::StartRecv(StatusCallback done) {
   }
 }
 
+Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
+                             const Device* target_device, EagerContext* ctx,
+                             SendPackedHandleOp* op) {
+  op->set_op_id(op_id);
+  for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
+    TensorHandle* h = nullptr;
+    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
+    if (h->Type() == TensorHandle::LOCAL) {
+      // AsProtoTensorContent doesn't work when the tensor is on the GPU, so
+      // copy it to the CPU before serializing it.
+      Tensor tensor;
+      TF_RETURN_IF_ERROR(h->CopyToDevice(*ctx, ctx->HostCPU(), &tensor));
+      auto* local_handle = op->add_handles()->mutable_local_handle();
+      local_handle->set_device(h->op_device() ? h->op_device()->name()
+                                              : ctx->HostCPU()->name());
+      tensor.AsProtoTensorContent(local_handle->mutable_tensor());
+    } else if (h->Type() == TensorHandle::REMOTE) {
+      // Only serialize the resource dtype and shape of the first handle, since
+      // all handles are of the same resource dtype and shape.
+      Device* src_device = absl::get<Device*>(h->device());
+      const bool serialize_resource_dtype_and_shape =
+          (i == 0) && (h->dtype == DT_RESOURCE) &&
+          (ctx->OnSameTask(src_device, target_device));
+      TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
+          h, op->add_handles()->mutable_remote_handle(), src_device,
+          absl::get<Device*>(h->DeviceOrHostCPU(*ctx))->name(),
+          serialize_resource_dtype_and_shape));
+    } else {
+      return errors::InvalidArgument("Nested packed handles are not supported");
+    }
+  }
+  return Status::OK();
+}
+
+void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) {
+  Status s;
+  const uint64 context_view_id = ctx_->GetContextViewId();
+  if (!send_device_->IsLocal()) {
+    s = errors::InvalidArgument(
+        "Copying a packed handle from a remote device is not supported");
+    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
+    done(s);
+    return;
+  }
+
+  EnqueueRequest request;
+  uint64 context_id = ctx_->GetContextId();
+  request.set_context_id(context_id);
+  s = SerializePackedHandle(recv_op_id_, src_, recv_device_, ctx_,
+                            request.add_queue()->mutable_send_packed_handle());
+  if (!s.ok()) {
+    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
+    done(s);
+    return;
+  }
+
+  TensorShape shape;
+  s = src_->Shape(&shape);
+  if (!s.ok()) {
+    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
+    done(s);
+    return;
+  }
+  captured_state_->SetSrcShape(shape);
+
+  core::RefCountPtr<eager::EagerClient> eager_client;
+  s = ctx_->GetClient(recv_device_, &eager_client);
+  if (!s.ok()) {
+    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
+    done(s);
+    return;
+  }
+
+  EnqueueResponse* response = new EnqueueResponse;
+  Device* recv_device = recv_device_;
+  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
+  eager_client->StreamingEnqueueAsync(
+      &request, response,
+      [captured_state, response, recv_device, context_view_id,
+       done](const Status& s) {
+        if (s.ok()) {
+          Status status = captured_state->dst()->SetRemoteShape(
+              captured_state->GetSrcShape(), recv_device, context_view_id);
+          if (!status.ok()) {
+            LOG(ERROR) << "Ignoring an error encountered when setting remote "
+                          "shape of tensor received by SendPackedHandle rpc: "
+                       << status.ToString();
+          }
+        } else {
+          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
+        }
+        done(s);
+        delete response;
+      });
+}
+
 void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
   Status s;
   EnqueueRequest request;
@@ -351,7 +449,11 @@ Status RemoteCopyNode::Prepare() {
 
 void RemoteCopyNode::RunAsync(StatusCallback done) {
   started_ = true;
-  if (ctx_->UseSendTensorRPC() && send_device_->IsLocal() &&
+  if (src_->Type() == TensorHandle::PACKED) {
+    return StartSendPackedHandle(std::move(done));
+  }
+
+  if ((ctx_->UseSendTensorRPC()) && send_device_->IsLocal() &&
       !recv_device_->IsLocal()) {
     return StartRemoteSendTensor(std::move(done));
   }
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.h b/tensorflow/core/distributed_runtime/eager/remote_copy_node.h
index a527cd47127..7816a24ed33 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.h
@@ -121,6 +121,9 @@ class RemoteCopyNode : public AsyncEagerNode {
   // SendTensor RPC *on the receiver*.
   void StartRemoteSendTensor(StatusCallback done);
 
+  // Send a local packed TensorHandle to a remote device.
+  void StartSendPackedHandle(StatusCallback done);
+
   // State that is captured by Send and/or Recv callbacks (depending on which
   // one(s) is remote) and outlives this node in the case of remote->remote
   // copy.
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index e9e21777d3f..3fe2bd486ba 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -69,6 +69,7 @@ message QueueItem {
     // enqueued in streaming call. Request with this item type waits for pending
    // nodes to finish on the remote executor and report status.
     SyncRemoteExecutorForStream sync_remote_executor_for_stream = 6;
+    SendPackedHandleOp send_packed_handle = 7;
   }
 }
 
@@ -238,6 +239,27 @@ message SendTensorOp {
   string device_name = 3;
 }
 
+// Send a packed TensorHandle to a remote worker.
+message SendPackedHandleOp {
+  // Op id of the remote packed TensorHandle.
+  int64 op_id = 1;
+
+  message LocalTensorHandle {
+    TensorProto tensor = 1;
+    // Device where the tensor is produced.
+    string device = 2;
+  }
+
+  message Handle {
+    oneof item {
+      LocalTensorHandle local_handle = 1;
+      RemoteTensorHandle remote_handle = 2;
+    }
+  }
+
+  repeated Handle handles = 2;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 //
 // Eager Service defines a TensorFlow service that executes operations eagerly