For a multi-host function, cache the dtypes and shapes of remote variable inputs on the default function device when lazy tensor copy is enabled.

With this change, the dtypes and shapes are serialized and sent to the default function device only once.

PiperOrigin-RevId: 280758012
Change-Id: I1560a9d171f627b0d20aae51dd5f35a3b4f2c437
Authored by Yujing Zhang on 2019-11-15 16:17:28 -08:00; committed by TensorFlower Gardener
parent fa0bfeb53e
commit c08edd7f92
6 changed files with 109 additions and 62 deletions
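The change boils down to a check-then-cache pattern: before attaching a resource variable's dtypes and shapes to an outgoing remote op, the handle is asked whether the target device already holds a "resource shape mirror"; the metadata is serialized only on a miss, and the miss is recorded so later calls skip the redundant copy. Below is a minimal, self-contained sketch of that pattern, not the actual TensorFlow code; DeviceId, DtypeAndShape, ResourceHandle, and MaybeSerializeDtypesAndShapes are hypothetical stand-ins for the TensorHandle / RemoteMgr machinery touched in this diff.

// Sketch only: illustrates serializing resource dtypes/shapes at most once
// per target device by consulting a per-device "shape mirror" cache.
#include <cstdint>
#include <set>
#include <string>
#include <vector>

using DeviceId = std::string;

struct DtypeAndShape {
  int dtype;                   // stand-in for a DataType enum value
  std::vector<int64_t> shape;  // stand-in for a PartialTensorShape
};

struct ResourceHandle {
  std::vector<DtypeAndShape> dtypes_and_shapes;
  // Devices that have already received the dtypes/shapes of this resource.
  std::set<DeviceId> shape_mirrors;

  bool HasResourceShapeMirror(const DeviceId& d) const {
    return shape_mirrors.count(d) > 0;
  }
};

// Returns true if the dtypes/shapes were serialized into the outgoing
// request; false if the target device already has a cached mirror.
bool MaybeSerializeDtypesAndShapes(ResourceHandle& handle,
                                   const DeviceId& target_device) {
  if (handle.HasResourceShapeMirror(target_device)) {
    return false;  // Already sent once; skip the redundant copy.
  }
  // ... serialize handle.dtypes_and_shapes into the remote op here ...
  handle.shape_mirrors.insert(target_device);  // Remember for next time.
  return true;
}

In the real change, the sender-side cache lives on the TensorHandle (resource_shape_mirrors_) and the receiving worker additionally caches the deserialized shapes in RemoteMgr's mirrored_resource_shape_map_, so both ends avoid repeating the work.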


@@ -780,12 +780,23 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
handle->Unref();
}
} else {
serialize_resource_dtype_and_shape = input->dtype == DT_RESOURCE;
serialize_resource_dtype_and_shape =
(input->dtype == DT_RESOURCE) &&
(!input->HasResourceShapeMirror(op->Device()));
}
}
auto* input_handle = remote_op->add_inputs();
TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
input, remote_op->add_inputs(), input_device, *input_device_name,
input, input_handle, input_device, *input_device_name,
serialize_resource_dtype_and_shape));
if (!input_handle->resource_dtypes_and_shapes().empty()) {
auto tensor_handle_data =
absl::make_unique<UnshapedRemoteTensorHandleData>(
input_handle->op_id(), input_handle->output_num(), remote_task,
context_id, ctx);
TF_RETURN_IF_ERROR(input->AddResourceShapeMirror(
std::move(tensor_handle_data), op->Device()));
}
}
}


@@ -470,6 +470,15 @@ bool TensorHandle::HasRemoteMirror(Device* d) {
return false;
}
bool TensorHandle::HasResourceShapeMirror(Device* d) {
tf_shared_lock l(resource_shape_mirrors_mutex_);
auto mirror = resource_shape_mirrors_.find(d);
if (mirror != resource_shape_mirrors_.end()) {
return true;
}
return false;
}
Status TensorHandle::AddUnshapedRemoteMirror(
std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
mutex_lock l(remote_mirrors_mutex_);
@@ -486,6 +495,17 @@ Status TensorHandle::AddUnshapedRemoteMirror(
return Status::OK();
}
Status TensorHandle::AddResourceShapeMirror(
std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
mutex_lock l(resource_shape_mirrors_mutex_);
auto ret = resource_shape_mirrors_.insert(std::make_pair(d, std::move(t)));
if (!ret.second) {
return errors::Internal("Attempted to duplicate a resource shape mirror.");
}
return Status::OK();
}
Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
Device* d) {
mutex_lock l(remote_mirrors_mutex_);


@@ -142,10 +142,13 @@ class TensorHandle : public core::RefCounted {
#if !defined(IS_MOBILE_PLATFORM)
bool HasRemoteMirror(Device* d);
bool HasResourceShapeMirror(Device* d);
Status AddUnshapedRemoteMirror(
std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);
Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t, Device* d);
Status AddResourceShapeMirror(
std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);
// Return the op_id and output num if the handle refers to a remote tensor.
Status RemoteAddress(Device* d, int64* op_id, int32* output_num) const;
@@ -245,6 +248,13 @@ class TensorHandle : public core::RefCounted {
tensorflow::Device* const resource_device_;
#if !defined(IS_MOBILE_PLATFORM)
// TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
// variable is ready, since we could get the shape locally without remote copy
// then.
mutable mutex resource_shape_mirrors_mutex_;
std::map<tensorflow::Device*, std::unique_ptr<UnshapedRemoteTensorHandleData>>
resource_shape_mirrors_ GUARDED_BY(resource_shape_mirrors_mutex_);
mutable mutex remote_mirrors_mutex_;
// TODO(gjn): Unshaped remote mirrors are not expected to be long-lived.
// Consider replacing the unshaped_remote_mirrors_ map with something more


@@ -57,6 +57,22 @@ Status RemoteMgr::GetTensorHandle(
return GetTensorHandleImpl(remote_handle, handle);
}
Status RemoteMgr::GetMirroredResourceShape(
const RemoteTensorHandleInternal& remote_handle,
std::vector<DtypeAndPartialTensorShape>* handle) {
tf_shared_lock l(mirrored_resource_shape_mu_);
auto iter = mirrored_resource_shape_map_.find(remote_handle);
if (iter == mirrored_resource_shape_map_.end()) {
return errors::InvalidArgument(
"Unable to find the relevant mirrored resource shape: Op ID: ",
remote_handle.op_id, ", Output num: ", remote_handle.output_num);
}
*handle = iter->second;
return Status::OK();
}
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
int64* op_id, int32* output_num) {
TF_RETURN_IF_ERROR(
@@ -74,18 +90,26 @@ Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
Status RemoteMgr::DeleteTensorHandle(
const RemoteTensorHandleInternal& remote_handle) {
mutex_lock l(remote_tensor_handle_mu_);
auto iter = remote_tensor_handle_map_.find(remote_handle);
if (iter == remote_tensor_handle_map_.end()) {
return errors::InvalidArgument(
"Unable to find the relevant tensor remote_handle: Op ID: ",
remote_handle.op_id, ", Output num: ", remote_handle.output_num);
{
mutex_lock l(remote_tensor_handle_mu_);
auto iter = remote_tensor_handle_map_.find(remote_handle);
if (iter != remote_tensor_handle_map_.end()) {
iter->second->Unref();
remote_tensor_handle_map_.erase(iter);
return Status::OK();
}
}
iter->second->Unref();
remote_tensor_handle_map_.erase(iter);
return Status::OK();
{
mutex_lock l(mirrored_resource_shape_mu_);
auto iter = mirrored_resource_shape_map_.find(remote_handle);
if (iter != mirrored_resource_shape_map_.end()) {
mirrored_resource_shape_map_.erase(iter);
return Status::OK();
}
}
return errors::InvalidArgument(
"Unable to find the relevant tensor remote_handle: Op ID: ",
remote_handle.op_id, ", Output num: ", remote_handle.output_num);
}
Status RemoteMgr::SerializeRemoteTensorHandle(
@@ -94,15 +118,8 @@ Status RemoteMgr::SerializeRemoteTensorHandle(
int64 op_id;
int32 output_num;
if (!in->RemoteAddress(device, &op_id, &output_num).ok()) {
mutex_lock l(remote_tensor_handle_mu_);
if (!GetRemoteTensorHandle(in, &op_id, &output_num).ok()) {
op_id = NextOpId();
output_num = 0;
in->SetRemoteOpIdAndOutputNumToLocalTensorHandle(op_id, output_num);
in->Ref();
remote_tensor_handle_map_.emplace(
RemoteTensorHandleInternal(op_id, output_num), in);
}
tf_shared_lock l(remote_tensor_handle_mu_);
TF_RETURN_IF_ERROR(GetRemoteTensorHandle(in, &op_id, &output_num));
}
out->Clear();
out->set_op_id(op_id);
@@ -150,10 +167,19 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
std::move(remote_handle_data), in.dtype(), device, parent_, out));
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
for (const auto& dtype_and_shape_proto : in.resource_dtypes_and_shapes()) {
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
dtype_and_shape_proto.dtype(),
TensorShape(dtype_and_shape_proto.shape())});
if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
&dtypes_and_shapes)
.ok()) {
for (const auto& dtype_and_shape_proto :
in.resource_dtypes_and_shapes()) {
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
dtype_and_shape_proto.dtype(),
TensorShape(dtype_and_shape_proto.shape())});
}
mutex_lock l(mirrored_resource_shape_mu_);
mirrored_resource_shape_map_.emplace(
RemoteTensorHandleInternal(in.op_id(), in.output_num()),
dtypes_and_shapes);
}
(*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
}


@@ -58,7 +58,7 @@ class RemoteMgr {
return next_op_id_++;
}
// Serialize a TensorHandle(local/remote) to a RemoteTensorHandle.
// Serialize a remote TensorHandle to a RemoteTensorHandle.
Status SerializeRemoteTensorHandle(
TensorHandle* in, RemoteTensorHandle* out, Device* device,
const string& device_name,
@@ -88,12 +88,20 @@ class RemoteMgr {
tensorflow::TensorHandle** handle)
SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
Status GetMirroredResourceShape(
const RemoteTensorHandleInternal& remote_handle,
std::vector<DtypeAndPartialTensorShape>* handle);
bool is_master_;
using RemoteTensorHandleMap =
gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*,
RemoteTensorHandleInternalHash,
RemoteTensorHandleInternalEquals>;
using MirroredResourceShapeMap = gtl::FlatMap<
RemoteTensorHandleInternal, std::vector<DtypeAndPartialTensorShape>,
RemoteTensorHandleInternalHash, RemoteTensorHandleInternalEquals>;
mutex remote_tensor_handle_mu_;
// This map maintains the TensorHandles that are required by remote workers
// in the cluster. Each map key is generated by the master, so it should be
@@ -101,6 +109,13 @@ class RemoteMgr {
RemoteTensorHandleMap remote_tensor_handle_map_
GUARDED_BY(remote_tensor_handle_mu_);
mutex mirrored_resource_shape_mu_;
// This map maintains the data types and shapes of resource variables required
// by remote workers in the cluster. Each map key is generated by the master,
// so it should be globally unique.
MirroredResourceShapeMap mirrored_resource_shape_map_
GUARDED_BY(mirrored_resource_shape_mu_);
EagerContext* parent_; // not owned.
mutex executor_map_mu_;


@@ -68,41 +68,6 @@ class RemoteMgrTest : public ::testing::Test {
EagerContext* ctx_;
};
TEST_F(RemoteMgrTest, LocalTensorHandle) {
TestRemoteMgr remote_mgr(true, ctx_);
Tensor t(DT_FLOAT, TensorShape({0}));
TensorHandle* handle;
TF_ASSERT_OK(TensorHandle::CreateLocalHandle(t, &handle));
EXPECT_EQ(nullptr, handle->device());
EXPECT_EQ(local_device_, handle->DeviceOrHostCPU(ctx_));
const uint64 op_id = remote_mgr.OpId();
EXPECT_EQ(1, op_id);
RemoteTensorHandle remote_handle;
TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
handle, &remote_handle, handle->device(),
handle->DeviceOrHostCPU(ctx_)->name()));
EXPECT_EQ(2, remote_mgr.OpId());
EXPECT_EQ(op_id, remote_handle.op_id());
EXPECT_EQ(0, remote_handle.output_num());
EXPECT_EQ(local_device_->name(), remote_handle.device());
TensorHandle* deserialized_handle;
TF_ASSERT_OK(remote_mgr.DeserializeRemoteTensorHandle(remote_handle,
&deserialized_handle));
tensorflow::TensorHandle* h;
TF_EXPECT_OK(remote_mgr.GetTensorHandle(
RemoteTensorHandleInternal(remote_handle), &h));
TF_ASSERT_OK(
remote_mgr.DeleteTensorHandle(RemoteTensorHandleInternal(remote_handle)));
EXPECT_FALSE(
remote_mgr.GetTensorHandle(RemoteTensorHandleInternal(remote_handle), &h)
.ok());
deserialized_handle->Unref();
handle->Unref();
}
TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
RemoteMgr remote_mgr(false, ctx_);
Tensor t(DT_FLOAT, TensorShape({0}));