Support running a remote function with packed input handles.
- Support copying a packed TensorHandle from a client to a remote worker.

PiperOrigin-RevId: 311404609
Change-Id: Iadf2c7793dc3631f7be05de611d059733bbfdd63
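At the C API level, the change lets a packed TFE_TensorHandle be fed to a function op and, when the op is placed on a remote task, ships the packed handle to that worker. Below is a minimal sketch of the intended usage, mirroring TestFunctionWithPackedInput in the first hunks; it assumes the experimental TFE_CreatePackedTensorHandle API from c_api_experimental.h, and the helper name, component handles, function name, and device string are placeholders set up elsewhere (e.g. via TFE_ContextSetServerDef and TFE_ContextAddFunctionDef):

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Sketch: pack two component handles into one packed handle, feed it to a
// registered function, and optionally place the function on a remote task.
// `ctx`, `h0`, `h1`, `fn_name`, and `remote_device` are assumed inputs.
TFE_TensorHandle* RunWithPackedInput(TFE_Context* ctx, TFE_TensorHandle* h0,
                                     TFE_TensorHandle* h1, const char* fn_name,
                                     const char* remote_device) {
  TF_Status* status = TF_NewStatus();
  TFE_TensorHandle* components[] = {h0, h1};
  int num_components = 2;
  TFE_TensorHandle* packed_handle =
      TFE_CreatePackedTensorHandle(ctx, components, &num_components, status);
  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteStatus(status);
    return nullptr;
  }

  TFE_Op* func = TFE_NewOp(ctx, fn_name, status);
  TFE_OpAddInput(func, packed_handle, status);  // packed handle as one input
  if (remote_device != nullptr) {
    TFE_OpSetDevice(func, remote_device, status);  // run the function remotely
  }
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(func, &retvals[0], &num_retvals, status);

  TFE_DeleteOp(func);
  TFE_DeleteTensorHandle(packed_handle);
  TF_DeleteStatus(status);
  return retvals[0];  // nullptr if execution failed
}
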
@@ -434,7 +434,7 @@ string AddVariablesFunction() {
  return def.SerializeAsString();
}

TEST(CAPI, TestFunctionWithPackedInput) {
void TestFunctionWithPackedInput(const bool remote) {
  tensorflow::ServerDef server_def = GetServerDef(3);

  // This server def has the task index set to 0.
@@ -502,6 +502,10 @@ TEST(CAPI, TestFunctionWithPackedInput) {
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(func, packed_handle, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  if (remote) {
    TFE_OpSetDevice(func, task1_name, status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  }

  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
@@ -537,6 +541,14 @@ TEST(CAPI, TestFunctionWithPackedInput) {
  worker_server2.release();
}

TEST(CAPI, TestLocalFunctionWithPackedInput) {
  TestFunctionWithPackedInput(/*remote=*/false);
}

TEST(CAPI, TestRemoteFunctionWithPackedInput) {
  TestFunctionWithPackedInput(/*remote=*/true);
}

void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
  tensorflow::ServerDef server_def = GetServerDef(2);

@@ -318,17 +318,14 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
}

Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                        const tensorflow::DataType dtype,
                                        const tensorflow::TensorShape& shape,
                                        EagerContext* ctx,
                                        TensorHandle** packed_handle) {
  if (handles.empty()) {
    return errors::InvalidArgument("Handles should not be empty.");
  }

  // Get the dtype and shape from the first handle since all handles have the
  // same dtype and shape.
  tensorflow::DataType dtype = handles.at(0)->dtype;
  tensorflow::TensorShape shape;
  TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
  ResourceHandleInfo resource_handle_info;
  if (dtype == DT_RESOURCE) {
    TF_RETURN_IF_ERROR(
@@ -360,6 +357,22 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
  return Status::OK();
}

Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                        EagerContext* ctx,
                                        TensorHandle** packed_handle) {
  if (handles.empty()) {
    return errors::InvalidArgument("Handles should not be empty.");
  }

  // Get the dtype and shape from the first handle since all handles have the
  // same dtype and shape.
  tensorflow::DataType dtype = handles.at(0)->dtype;
  tensorflow::TensorShape shape;
  TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
  return CreatePackedHandle(std::move(handles), dtype, shape, ctx,
                            packed_handle);
}

TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
                           const tensorflow::DataType dtype,
                           const tensorflow::TensorShape& shape,

@@ -91,6 +91,11 @@ class TensorHandle : public AbstractTensorHandleInterface,
  // Create a handle which packs the given handles of the same dtype and shape.
  // If handles are on different devices, assign the packed handle to a
  // CompositeDevice.
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   const tensorflow::DataType dtype,
                                   const tensorflow::TensorShape& shape,
                                   EagerContext* ctx,
                                   TensorHandle** packed_handle);
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   EagerContext* ctx,
                                   TensorHandle** packed_handle);

@@ -524,6 +524,8 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
      s = context->Context()->Executor().AddOrExecute(std::move(node));
    } else if (item.has_send_tensor()) {
      s = SendTensor(item.send_tensor(), context->Context());
    } else if (item.has_send_packed_handle()) {
      s = SendPackedHandle(item.send_packed_handle(), context->Context());
    } else if (item.has_register_function()) {
      s = RegisterFunction(item.register_function(), context->Context());
    } else if (item.has_cleanup_function()) {
@@ -643,6 +645,52 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
  return Status::OK();
}

Status EagerServiceImpl::SendPackedHandle(
    const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) {
  if (send_packed_handle.handles().empty()) {
    return errors::InvalidArgument("Handles should not be empty.");
  }

  std::vector<tensorflow::TensorHandle*> handles;
  handles.resize(send_packed_handle.handles_size());
  for (int i = 0; i < send_packed_handle.handles_size(); ++i) {
    const auto& item = send_packed_handle.handles(i);
    if (item.has_local_handle()) {
      Tensor tensor;
      if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) {
        return errors::InvalidArgument(
            "Invalid TensorProto: ",
            item.local_handle().tensor().DebugString());
      }
      Device* op_device = nullptr;
      TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
          item.local_handle().device().c_str(), &op_device));
      handles[i] = TensorHandle::CreateLocalHandle(
          std::move(tensor), /*d=*/nullptr, op_device, eager_context);
    } else {
      TF_RETURN_IF_ERROR(
          eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
              item.remote_handle(), &handles[i]));
    }
  }

  tensorflow::TensorHandle* packed_handle = nullptr;
  std::vector<tensorflow::TensorHandle*> handles_to_pack = handles;
  // Create an unshaped packed TensorHandle.
  TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
      std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
      eager_context, &packed_handle));

  for (auto* h : handles) {
    // Unref handle since it has a ref in the packed handle now.
    h->Unref();
  }

  eager_context->RemoteMgr()->AddOperationOutputs({packed_handle},
                                                  send_packed_handle.op_id());
  return Status::OK();
}

tensorflow::Status EagerServiceImpl::GetServerContext(
    uint64 context_id, ServerContext** server_context) {
  tf_shared_lock l(contexts_mu_);

@@ -212,6 +212,8 @@ class EagerServiceImpl {
                  QueueResponse* queue_response);
  Status SendTensor(const SendTensorOp& send_tensor,
                    EagerContext* eager_context);
  Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle,
                          EagerContext* eager_context);
  Status RegisterFunction(const RegisterFunctionOp& register_function,
                          EagerContext* eager_context);
  Status CleanupFunction(const CleanupFunctionOp& cleanup_function);

@@ -881,6 +881,108 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
                                           &close_context_response));
}

// Test serializes and sends a packed TensorHandle.
TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
  const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";

  uint64 context_id = random::New64();
  CreateContextRequest request;
  auto* server_def = request.mutable_server_def();
  server_def->set_job_name("localhost");
  server_def->set_task_index(0);
  request.add_cluster_device_attributes()->set_name(device0);
  request.add_cluster_device_attributes()->set_name(device1);
  request.add_cluster_device_attributes()->set_name(device2);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  // Copy a tensor to device0
  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  // Copy a packed handle to device0
  auto* send_packed_handle =
      remote_enqueue_request.add_queue()->mutable_send_packed_handle();
  send_packed_handle->set_op_id(3);
  RemoteTensorHandle* remote_handle =
      send_packed_handle->add_handles()->mutable_remote_handle();
  remote_handle->set_op_id(send_tensor->op_id());
  remote_handle->set_output_num(0);
  remote_handle->set_op_device(device0);
  remote_handle->set_device(device0);

  SendPackedHandleOp::LocalTensorHandle* local_handle =
      send_packed_handle->add_handles()->mutable_local_handle();
  SetTensorProto(local_handle->mutable_tensor());
  local_handle->set_device(device1);

  remote_handle = send_packed_handle->add_handles()->mutable_remote_handle();
  remote_handle->set_op_id(2);
  remote_handle->set_output_num(5);
  remote_handle->set_op_device(device2);
  remote_handle->set_device(device2);

  TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
                                          &remote_enqueue_response));

  tensorflow::TensorHandle* packed_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(3, 0), &packed_handle));

  EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
  EXPECT_EQ(packed_handle->NumPackedHandles(), 3);

  TensorHandle* handle0 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
  EXPECT_EQ(handle0->Type(), TensorHandle::LOCAL);
  EXPECT_EQ(handle0->op_device()->name(), device0);
  const Tensor* t0 = nullptr;
  TF_ASSERT_OK(handle0->Tensor(&t0));
  auto actual = t0->flat<float>();
  EXPECT_EQ(4, actual.size());
  EXPECT_EQ(1.0, actual(0));
  EXPECT_EQ(2.0, actual(1));
  EXPECT_EQ(3.0, actual(2));
  EXPECT_EQ(4.0, actual(3));

  TensorHandle* handle1 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(1, &handle1));
  EXPECT_EQ(handle1->Type(), TensorHandle::LOCAL);
  EXPECT_EQ(handle1->op_device()->name(), device1);
  const Tensor* t1 = nullptr;
  TF_ASSERT_OK(handle0->Tensor(&t1));
  EXPECT_EQ(t1, t0);

  TensorHandle* handle2 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(2, &handle2));
  EXPECT_EQ(handle2->Type(), TensorHandle::REMOTE);
  EXPECT_EQ(handle2->op_device()->name(), device2);
  int64 op_id;
  int32 output_num;
  TF_ASSERT_OK(handle2->RemoteAddressUntilReady(
      absl::get<Device*>(handle2->device()), &op_id, &output_num));
  EXPECT_EQ(op_id, 2);
  EXPECT_EQ(output_num, 5);

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

// Test requests sent to the eager service on master.
TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
  tensorflow::Rendezvous* rendezvous =

@@ -25,6 +25,8 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace eager {
@@ -290,6 +292,102 @@ void RemoteCopyNode::StartRecv(StatusCallback done) {
  }
}

Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
                             const Device* target_device, EagerContext* ctx,
                             SendPackedHandleOp* op) {
  op->set_op_id(op_id);
  for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
    TensorHandle* h = nullptr;
    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
    if (h->Type() == TensorHandle::LOCAL) {
      // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
      // copy it to the CPU before copying it out.
      Tensor tensor;
      TF_RETURN_IF_ERROR(h->CopyToDevice(*ctx, ctx->HostCPU(), &tensor));
      auto* local_handle = op->add_handles()->mutable_local_handle();
      local_handle->set_device(h->op_device() ? h->op_device()->name()
                                              : ctx->HostCPU()->name());
      tensor.AsProtoTensorContent(local_handle->mutable_tensor());
    } else if (h->Type() == TensorHandle::REMOTE) {
      // Only serialize the resource dtype and shape of the first handle, since
      // all handles are of the same resource dtype and shape.
      Device* src_device = absl::get<Device*>(h->device());
      const bool serialize_resource_dtype_and_shape =
          (i == 0) && (h->dtype == DT_RESOURCE) &&
          (ctx->OnSameTask(src_device, target_device));
      TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
          h, op->add_handles()->mutable_remote_handle(), src_device,
          absl::get<Device*>(h->DeviceOrHostCPU(*ctx))->name(),
          serialize_resource_dtype_and_shape));
    } else {
      return errors::InvalidArgument("Nested packed handles are not supported");
    }
  }
  return Status::OK();
}

void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) {
  Status s;
  const uint64 context_view_id = ctx_->GetContextViewId();
  if (!send_device_->IsLocal()) {
    s = errors::InvalidArgument(
        "Copying a packed handle from a remote device is not supported");
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  s = SerializePackedHandle(recv_op_id_, src_, recv_device_, ctx_,
                            request.add_queue()->mutable_send_packed_handle());
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  TensorShape shape;
  s = src_->Shape(&shape);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }
  captured_state_->SetSrcShape(shape);

  core::RefCountPtr<eager::EagerClient> eager_client;
  s = ctx_->GetClient(recv_device_, &eager_client);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  EnqueueResponse* response = new EnqueueResponse;
  Device* recv_device = recv_device_;
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  eager_client->StreamingEnqueueAsync(
      &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              captured_state->GetSrcShape(), recv_device, context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by SendPackedHandle rpc: "
                       << status.ToString();
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}

void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
  Status s;
  EnqueueRequest request;
@@ -351,7 +449,11 @@ Status RemoteCopyNode::Prepare() {

void RemoteCopyNode::RunAsync(StatusCallback done) {
  started_ = true;
  if (ctx_->UseSendTensorRPC() && send_device_->IsLocal() &&
  if (src_->Type() == TensorHandle::PACKED) {
    return StartSendPackedHandle(std::move(done));
  }

  if ((ctx_->UseSendTensorRPC()) && send_device_->IsLocal() &&
      !recv_device_->IsLocal()) {
    return StartRemoteSendTensor(std::move(done));
  }

@@ -121,6 +121,9 @@ class RemoteCopyNode : public AsyncEagerNode {
  // SendTensor RPC *on the receiver*.
  void StartRemoteSendTensor(StatusCallback done);

  // Send a local packed TensorHandle to a remote device.
  void StartSendPackedHandle(StatusCallback done);

  // State that is captured by Send and/or Recv callbacks (depending on which
  // one(s) is remote) and outlives this node in the case of remote->remote
  // copy.

@@ -69,6 +69,7 @@ message QueueItem {
    // enqueued in streaming call. Request with this item type waits for pending
    // nodes to finish on the remote executor and report status.
    SyncRemoteExecutorForStream sync_remote_executor_for_stream = 6;
    SendPackedHandleOp send_packed_handle = 7;
  }
}

@@ -238,6 +239,27 @@ message SendTensorOp {
  string device_name = 3;
}

// Send a packed TensorHandle to a remote worker.
message SendPackedHandleOp {
  // Op id of the remote packed TensorHandle.
  int64 op_id = 1;

  message LocalTensorHandle {
    TensorProto tensor = 1;
    // Device where the tensor is produced.
    string device = 2;
  }

  message Handle {
    oneof item {
      LocalTensorHandle local_handle = 1;
      RemoteTensorHandle remote_handle = 2;
    }
  }

  repeated Handle handles = 2;
}

////////////////////////////////////////////////////////////////////////////////
//
// Eager Service defines a TensorFlow service that executes operations eagerly

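For reference, a hedged sketch of how a client might populate the new SendPackedHandleOp message with one inline local tensor and one reference to a tensor that already lives on the target worker. Field names come from the proto above; the helper name MakeSendPackedHandleOp, the op ids, and the device strings are placeholders, and the enclosing EnqueueRequest plumbing is the same as in SendPackedHandleTest earlier in this diff:

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"

// Sketch: build a SendPackedHandleOp with two components. Concrete ids and
// device names below are made up for illustration.
tensorflow::eager::SendPackedHandleOp MakeSendPackedHandleOp() {
  tensorflow::eager::SendPackedHandleOp op;
  op.set_op_id(42);  // id under which the packed handle is stored on the worker

  // Component 0: a local tensor, shipped inline as a TensorProto.
  auto* local = op.add_handles()->mutable_local_handle();
  local->mutable_tensor()->set_dtype(tensorflow::DT_FLOAT);
  local->mutable_tensor()->mutable_tensor_shape()->add_dim()->set_size(1);
  local->mutable_tensor()->add_float_val(1.0f);
  local->set_device("/job:worker/replica:0/task:0/device:CPU:0");

  // Component 1: a reference to a tensor already produced on a remote task.
  auto* remote = op.add_handles()->mutable_remote_handle();
  remote->set_op_id(7);
  remote->set_output_num(0);
  remote->set_device("/job:worker/replica:0/task:1/device:CPU:0");
  remote->set_op_device("/job:worker/replica:0/task:1/device:CPU:0");
  return op;
}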