From c08edd7f929ae9ecb19c9116337597d94d249469 Mon Sep 17 00:00:00 2001
From: Yujing Zhang
Date: Fri, 15 Nov 2019 16:17:28 -0800
Subject: [PATCH] For a multi-host function, cache dtypes and shapes of remote
 variable inputs on the default function device, when lazy tensor copy is
 enabled. With this change, the dtypes and shapes are only serialized and sent
 to the default function device once.

PiperOrigin-RevId: 280758012
Change-Id: I1560a9d171f627b0d20aae51dd5f35a3b4f2c437
---
 .../core/common_runtime/eager/execute.cc      | 15 +++-
 .../common_runtime/eager/tensor_handle.cc     | 20 +++++
 .../core/common_runtime/eager/tensor_handle.h | 10 +++
 .../distributed_runtime/eager/remote_mgr.cc   | 74 +++++++++++++------
 .../distributed_runtime/eager/remote_mgr.h    | 17 ++++-
 .../eager/remote_mgr_test.cc                  | 35 ---------
 6 files changed, 109 insertions(+), 62 deletions(-)

diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index ad8e2a29277..c819f0f719f 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -780,12 +780,23 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
             handle->Unref();
           }
         } else {
-          serialize_resource_dtype_and_shape = input->dtype == DT_RESOURCE;
+          serialize_resource_dtype_and_shape =
+              (input->dtype == DT_RESOURCE) &&
+              (!input->HasResourceShapeMirror(op->Device()));
         }
       }
+      auto* input_handle = remote_op->add_inputs();
       TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
-          input, remote_op->add_inputs(), input_device, *input_device_name,
+          input, input_handle, input_device, *input_device_name,
           serialize_resource_dtype_and_shape));
+      if (!input_handle->resource_dtypes_and_shapes().empty()) {
+        auto tensor_handle_data =
+            absl::make_unique<UnshapedRemoteTensorHandleData>(
+                input_handle->op_id(), input_handle->output_num(), remote_task,
+                context_id, ctx);
+        TF_RETURN_IF_ERROR(input->AddResourceShapeMirror(
+            std::move(tensor_handle_data), op->Device()));
+      }
     }
   }
 
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 9af9b2387a2..1be22c7c1f7 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -470,6 +470,15 @@ bool TensorHandle::HasRemoteMirror(Device* d) {
   return false;
 }
 
+bool TensorHandle::HasResourceShapeMirror(Device* d) {
+  tf_shared_lock l(resource_shape_mirrors_mutex_);
+  auto mirror = resource_shape_mirrors_.find(d);
+  if (mirror != resource_shape_mirrors_.end()) {
+    return true;
+  }
+  return false;
+}
+
 Status TensorHandle::AddUnshapedRemoteMirror(
     std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
   mutex_lock l(remote_mirrors_mutex_);
@@ -486,6 +495,17 @@ Status TensorHandle::AddUnshapedRemoteMirror(
   return Status::OK();
 }
 
+Status TensorHandle::AddResourceShapeMirror(
+    std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
+  mutex_lock l(resource_shape_mirrors_mutex_);
+  auto ret = resource_shape_mirrors_.insert(std::make_pair(d, std::move(t)));
+  if (!ret.second) {
+    return errors::Internal("Attempted to duplicate a resource shape mirror.");
+  }
+
+  return Status::OK();
+}
+
 Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
                                      Device* d) {
   mutex_lock l(remote_mirrors_mutex_);
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index c1c824fbe8a..d2be6a4555f 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -142,10 +142,13 @@ class TensorHandle : public core::RefCounted {
 
 #if !defined(IS_MOBILE_PLATFORM)
   bool HasRemoteMirror(Device* d);
+  bool HasResourceShapeMirror(Device* d);
 
   Status AddUnshapedRemoteMirror(
       std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);
   Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t, Device* d);
+  Status AddResourceShapeMirror(
+      std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);
 
   // Return the op_id and output num if the handle refers to a remote tensor.
   Status RemoteAddress(Device* d, int64* op_id, int32* output_num) const;
@@ -245,6 +248,13 @@ class TensorHandle : public core::RefCounted {
   tensorflow::Device* const resource_device_;
 
 #if !defined(IS_MOBILE_PLATFORM)
+  // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
+  // variable is ready, since we could get the shape locally without remote copy
+  // then.
+  mutable mutex resource_shape_mirrors_mutex_;
+  std::map<tensorflow::Device*, std::unique_ptr<UnshapedRemoteTensorHandleData>>
+      resource_shape_mirrors_ GUARDED_BY(resource_shape_mirrors_mutex_);
+
   mutable mutex remote_mirrors_mutex_;
   // TODO(gjn): Unshaped remote mirrors are not expected to be long-lived.
   // Consider replacing the unshaped_remote_mirrors_ map with something more
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
index c1812914012..a77d0cf41d9 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
@@ -57,6 +57,22 @@ Status RemoteMgr::GetTensorHandle(
   return GetTensorHandleImpl(remote_handle, handle);
 }
 
+Status RemoteMgr::GetMirroredResourceShape(
+    const RemoteTensorHandleInternal& remote_handle,
+    std::vector<DtypeAndPartialTensorShape>* handle) {
+  tf_shared_lock l(mirrored_resource_shape_mu_);
+  auto iter = mirrored_resource_shape_map_.find(remote_handle);
+  if (iter == mirrored_resource_shape_map_.end()) {
+    return errors::InvalidArgument(
+        "Unable to find the relevant mirrored resource shape: Op ID: ",
+        remote_handle.op_id, ", Output num: ", remote_handle.output_num);
+  }
+
+  *handle = iter->second;
+
+  return Status::OK();
+}
+
 Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
                                         int64* op_id, int32* output_num) {
   TF_RETURN_IF_ERROR(
@@ -74,18 +90,26 @@ Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
 
 Status RemoteMgr::DeleteTensorHandle(
     const RemoteTensorHandleInternal& remote_handle) {
-  mutex_lock l(remote_tensor_handle_mu_);
-  auto iter = remote_tensor_handle_map_.find(remote_handle);
-  if (iter == remote_tensor_handle_map_.end()) {
-    return errors::InvalidArgument(
-        "Unable to find the relevant tensor remote_handle: Op ID: ",
-        remote_handle.op_id, ", Output num: ", remote_handle.output_num);
+  {
+    mutex_lock l(remote_tensor_handle_mu_);
+    auto iter = remote_tensor_handle_map_.find(remote_handle);
+    if (iter != remote_tensor_handle_map_.end()) {
+      iter->second->Unref();
+      remote_tensor_handle_map_.erase(iter);
+      return Status::OK();
+    }
   }
-
-  iter->second->Unref();
-  remote_tensor_handle_map_.erase(iter);
-
-  return Status::OK();
+  {
+    mutex_lock l(mirrored_resource_shape_mu_);
+    auto iter = mirrored_resource_shape_map_.find(remote_handle);
+    if (iter != mirrored_resource_shape_map_.end()) {
+      mirrored_resource_shape_map_.erase(iter);
+      return Status::OK();
+    }
+  }
+  return errors::InvalidArgument(
+      "Unable to find the relevant tensor remote_handle: Op ID: ",
+      remote_handle.op_id, ", Output num: ", remote_handle.output_num);
 }
 
 Status RemoteMgr::SerializeRemoteTensorHandle(
@@ -94,15 +118,8 @@ Status RemoteMgr::SerializeRemoteTensorHandle(
   int64 op_id;
   int32 output_num;
   if (!in->RemoteAddress(device, &op_id, &output_num).ok()) {
-    mutex_lock l(remote_tensor_handle_mu_);
-    if (!GetRemoteTensorHandle(in, &op_id, &output_num).ok()) {
-      op_id = NextOpId();
-      output_num = 0;
-      in->SetRemoteOpIdAndOutputNumToLocalTensorHandle(op_id, output_num);
-      in->Ref();
-      remote_tensor_handle_map_.emplace(
-          RemoteTensorHandleInternal(op_id, output_num), in);
-    }
+    tf_shared_lock l(remote_tensor_handle_mu_);
+    TF_RETURN_IF_ERROR(GetRemoteTensorHandle(in, &op_id, &output_num));
   }
   out->Clear();
   out->set_op_id(op_id);
@@ -150,10 +167,19 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
     TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
         std::move(remote_handle_data), in.dtype(), device, parent_, out));
     std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
-    for (const auto& dtype_and_shape_proto : in.resource_dtypes_and_shapes()) {
-      dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
-          dtype_and_shape_proto.dtype(),
-          TensorShape(dtype_and_shape_proto.shape())});
+    if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
+                                  &dtypes_and_shapes)
+             .ok()) {
+      for (const auto& dtype_and_shape_proto :
+           in.resource_dtypes_and_shapes()) {
+        dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
+            dtype_and_shape_proto.dtype(),
+            TensorShape(dtype_and_shape_proto.shape())});
+      }
+      mutex_lock l(mirrored_resource_shape_mu_);
+      mirrored_resource_shape_map_.emplace(
+          RemoteTensorHandleInternal(in.op_id(), in.output_num()),
+          dtypes_and_shapes);
     }
     (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
   }
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.h b/tensorflow/core/distributed_runtime/eager/remote_mgr.h
index 6853fbb40cd..4fd3b09fbbb 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.h
@@ -58,7 +58,7 @@ class RemoteMgr {
     return next_op_id_++;
   }
 
-  // Serialize a TensorHandle(local/remote) to a RemoteTensorHandle.
+  // Serialize a remote TensorHandle to a RemoteTensorHandle.
   Status SerializeRemoteTensorHandle(
       TensorHandle* in, RemoteTensorHandle* out, Device* device,
       const string& device_name,
@@ -88,12 +88,20 @@ class RemoteMgr {
                              tensorflow::TensorHandle** handle)
       SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
 
+  Status GetMirroredResourceShape(
+      const RemoteTensorHandleInternal& remote_handle,
+      std::vector<DtypeAndPartialTensorShape>* handle);
+
   bool is_master_;
 
   using RemoteTensorHandleMap =
      gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*,
                   RemoteTensorHandleInternalHash,
                   RemoteTensorHandleInternalEquals>;
+  using MirroredResourceShapeMap = gtl::FlatMap<
+      RemoteTensorHandleInternal, std::vector<DtypeAndPartialTensorShape>,
+      RemoteTensorHandleInternalHash, RemoteTensorHandleInternalEquals>;
+
   mutex remote_tensor_handle_mu_;
   // This map maintains the TensorHandles that are required by remote workers
   // in the cluster. Each map key is generated by the master, so it should be
@@ -101,6 +109,13 @@
   RemoteTensorHandleMap remote_tensor_handle_map_
       GUARDED_BY(remote_tensor_handle_mu_);
 
+  mutex mirrored_resource_shape_mu_;
+  // This map maintains the data types and shapes of resource variables required
+  // by remote workers in the cluster. Each map key is generated by the master,
+  // so it should be globally unique.
+  MirroredResourceShapeMap mirrored_resource_shape_map_
+      GUARDED_BY(mirrored_resource_shape_mu_);
+
   EagerContext* parent_;  // not owned.
 
   mutex executor_map_mu_;
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
index 6bb4943ffee..312637b9965 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
@@ -68,41 +68,6 @@ class RemoteMgrTest : public ::testing::Test {
   EagerContext* ctx_;
 };
 
-TEST_F(RemoteMgrTest, LocalTensorHandle) {
-  TestRemoteMgr remote_mgr(true, ctx_);
-  Tensor t(DT_FLOAT, TensorShape({0}));
-
-  TensorHandle* handle;
-  TF_ASSERT_OK(TensorHandle::CreateLocalHandle(t, &handle));
-  EXPECT_EQ(nullptr, handle->device());
-  EXPECT_EQ(local_device_, handle->DeviceOrHostCPU(ctx_));
-  const uint64 op_id = remote_mgr.OpId();
-  EXPECT_EQ(1, op_id);
-  RemoteTensorHandle remote_handle;
-  TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
-      handle, &remote_handle, handle->device(),
-      handle->DeviceOrHostCPU(ctx_)->name()));
-  EXPECT_EQ(2, remote_mgr.OpId());
-  EXPECT_EQ(op_id, remote_handle.op_id());
-  EXPECT_EQ(0, remote_handle.output_num());
-  EXPECT_EQ(local_device_->name(), remote_handle.device());
-
-  TensorHandle* deserialized_handle;
-  TF_ASSERT_OK(remote_mgr.DeserializeRemoteTensorHandle(remote_handle,
-                                                        &deserialized_handle));
-  tensorflow::TensorHandle* h;
-  TF_EXPECT_OK(remote_mgr.GetTensorHandle(
-      RemoteTensorHandleInternal(remote_handle), &h));
-  TF_ASSERT_OK(
-      remote_mgr.DeleteTensorHandle(RemoteTensorHandleInternal(remote_handle)));
-  EXPECT_FALSE(
-      remote_mgr.GetTensorHandle(RemoteTensorHandleInternal(remote_handle), &h)
-          .ok());
-
-  deserialized_handle->Unref();
-  handle->Unref();
-}
-
 TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
   RemoteMgr remote_mgr(false, ctx_);
   Tensor t(DT_FLOAT, TensorShape({0}));
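
Note: the caching pattern this patch introduces (serialize a resource variable's dtype/shape metadata only the first time a handle is sent to the default function device, and answer later requests from a local mirror) can be sketched outside TensorFlow as follows. This is an illustrative sketch only; ResourceShapeCache and DtypeAndShape are hypothetical names, not the types added by the patch.

// Standalone sketch of the caching idea (hypothetical names, not TF code):
// dtype/shape metadata for a remote resource handle, keyed by
// (op_id, output_num), is sent only on the first serialization; later
// serializations hit the local cache and skip the metadata.
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

struct DtypeAndShape {
  int dtype;
  std::vector<int64_t> dims;
};

class ResourceShapeCache {
 public:
  // Returns true and fills `out` if metadata for (op_id, output_num) was
  // cached by an earlier request.
  bool Lookup(int64_t op_id, int32_t output_num,
              std::vector<DtypeAndShape>* out) const {
    auto it = cache_.find({op_id, output_num});
    if (it == cache_.end()) return false;
    *out = it->second;
    return true;
  }

  // Remembers the metadata the first time it is produced for a handle.
  void Insert(int64_t op_id, int32_t output_num,
              std::vector<DtypeAndShape> shapes) {
    cache_.emplace(std::make_pair(op_id, output_num), std::move(shapes));
  }

 private:
  std::map<std::pair<int64_t, int32_t>, std::vector<DtypeAndShape>> cache_;
};

int main() {
  ResourceShapeCache cache;
  std::vector<DtypeAndShape> metadata;
  // First serialization of handle (1, 0): cache miss, so the caller would
  // serialize the dtype/shape and attach it to the outgoing handle.
  if (!cache.Lookup(/*op_id=*/1, /*output_num=*/0, &metadata)) {
    cache.Insert(1, 0, {{/*dtype=*/20, /*dims=*/{4, 8}}});
  }
  // Subsequent serializations: cache hit, nothing extra is sent.
  std::cout << "cached after first send: " << std::boolalpha
            << cache.Lookup(1, 0, &metadata) << "\n";
  return 0;
}

Keying the sketch's cache by (op_id, output_num) mirrors how RemoteTensorHandleInternal identifies a remote tensor in the patch, which is also why DeleteTensorHandle above can evict the cached shapes when the handle itself is deleted.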