diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index ad8e2a29277..c819f0f719f 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -780,12 +780,23 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
           handle->Unref();
         }
       } else {
-        serialize_resource_dtype_and_shape = input->dtype == DT_RESOURCE;
+        serialize_resource_dtype_and_shape =
+            (input->dtype == DT_RESOURCE) &&
+            (!input->HasResourceShapeMirror(op->Device()));
       }
     }
+    auto* input_handle = remote_op->add_inputs();
     TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
-        input, remote_op->add_inputs(), input_device, *input_device_name,
+        input, input_handle, input_device, *input_device_name,
         serialize_resource_dtype_and_shape));
+    if (!input_handle->resource_dtypes_and_shapes().empty()) {
+      auto tensor_handle_data =
+          absl::make_unique<UnshapedRemoteTensorHandleData>(
+              input_handle->op_id(), input_handle->output_num(), remote_task,
+              context_id, ctx);
+      TF_RETURN_IF_ERROR(input->AddResourceShapeMirror(
+          std::move(tensor_handle_data), op->Device()));
+    }
   }
 }
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 9af9b2387a2..1be22c7c1f7 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -470,6 +470,15 @@ bool TensorHandle::HasRemoteMirror(Device* d) {
   return false;
 }
 
+bool TensorHandle::HasResourceShapeMirror(Device* d) {
+  tf_shared_lock l(resource_shape_mirrors_mutex_);
+  auto mirror = resource_shape_mirrors_.find(d);
+  if (mirror != resource_shape_mirrors_.end()) {
+    return true;
+  }
+  return false;
+}
+
 Status TensorHandle::AddUnshapedRemoteMirror(
     std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
   mutex_lock l(remote_mirrors_mutex_);
@@ -486,6 +495,17 @@ Status TensorHandle::AddUnshapedRemoteMirror(
   return Status::OK();
 }
 
+Status TensorHandle::AddResourceShapeMirror(
+    std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
+  mutex_lock l(resource_shape_mirrors_mutex_);
+  auto ret = resource_shape_mirrors_.insert(std::make_pair(d, std::move(t)));
+  if (!ret.second) {
+    return errors::Internal("Attempted to duplicate a resource shape mirror.");
+  }
+
+  return Status::OK();
+}
+
 Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
                                      Device* d) {
   mutex_lock l(remote_mirrors_mutex_);
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index c1c824fbe8a..d2be6a4555f 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -142,10 +142,13 @@ class TensorHandle : public core::RefCounted {
 
 #if !defined(IS_MOBILE_PLATFORM)
   bool HasRemoteMirror(Device* d);
+  bool HasResourceShapeMirror(Device* d);
 
   Status AddUnshapedRemoteMirror(
       std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);
   Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t, Device* d);
+  Status AddResourceShapeMirror(
+      std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);
 
   // Return the op_id and output num if the handle refers to a remote tensor.
   Status RemoteAddress(Device* d, int64* op_id, int32* output_num) const;
@@ -245,6 +248,13 @@ class TensorHandle : public core::RefCounted {
   tensorflow::Device* const resource_device_;
 
 #if !defined(IS_MOBILE_PLATFORM)
+  // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable
+  // per-replica variable is ready, since we could get the shape locally
+  // without remote copy then.
+  mutable mutex resource_shape_mirrors_mutex_;
+  std::map<tensorflow::Device*, std::unique_ptr<UnshapedRemoteTensorHandleData>>
+      resource_shape_mirrors_ GUARDED_BY(resource_shape_mirrors_mutex_);
+
   mutable mutex remote_mirrors_mutex_;
   // TODO(gjn): Unshaped remote mirrors are not expected to be long-lived.
   // Consider replacing the unshaped_remote_mirrors_ map with something more
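The new accessors above form a check-then-insert pattern over a per-device map: lookups take a shared lock, insertion takes an exclusive lock, and a duplicate insert is reported as an internal error. Below is a minimal standalone sketch of that pattern, with `std::shared_mutex` standing in for TensorFlow's `mutex`/`tf_shared_lock` and a stand-in `MirrorData` for `UnshapedRemoteTensorHandleData`:

```cpp
// Minimal sketch of the per-device resource shape mirror map.
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>

struct Device;  // opaque key: one map entry per remote device

struct MirrorData {
  int64_t op_id;
  int32_t output_num;
};

class Handle {
 public:
  // Readers take a shared lock, so concurrent lookups do not serialize.
  bool HasResourceShapeMirror(Device* d) {
    std::shared_lock<std::shared_mutex> l(mu_);
    return mirrors_.find(d) != mirrors_.end();
  }

  // Writers take an exclusive lock; a second insert for the same device
  // fails, mirroring the errors::Internal path in AddResourceShapeMirror.
  bool AddResourceShapeMirror(std::unique_ptr<MirrorData> t, Device* d) {
    std::unique_lock<std::shared_mutex> l(mu_);
    return mirrors_.emplace(d, std::move(t)).second;
  }

 private:
  std::shared_mutex mu_;
  std::map<Device*, std::unique_ptr<MirrorData>> mirrors_;
};
```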
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
index c1812914012..a77d0cf41d9 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
@@ -57,6 +57,22 @@ Status RemoteMgr::GetTensorHandle(
   return GetTensorHandleImpl(remote_handle, handle);
 }
 
+Status RemoteMgr::GetMirroredResourceShape(
+    const RemoteTensorHandleInternal& remote_handle,
+    std::vector<DtypeAndPartialTensorShape>* handle) {
+  tf_shared_lock l(mirrored_resource_shape_mu_);
+  auto iter = mirrored_resource_shape_map_.find(remote_handle);
+  if (iter == mirrored_resource_shape_map_.end()) {
+    return errors::InvalidArgument(
+        "Unable to find the relevant mirrored resource shape: Op ID: ",
+        remote_handle.op_id, ", Output num: ", remote_handle.output_num);
+  }
+
+  *handle = iter->second;
+
+  return Status::OK();
+}
+
 Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
                                         int64* op_id, int32* output_num) {
   TF_RETURN_IF_ERROR(
@@ -74,18 +90,26 @@ Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
 
 Status RemoteMgr::DeleteTensorHandle(
     const RemoteTensorHandleInternal& remote_handle) {
-  mutex_lock l(remote_tensor_handle_mu_);
-  auto iter = remote_tensor_handle_map_.find(remote_handle);
-  if (iter == remote_tensor_handle_map_.end()) {
-    return errors::InvalidArgument(
-        "Unable to find the relevant tensor remote_handle: Op ID: ",
-        remote_handle.op_id, ", Output num: ", remote_handle.output_num);
+  {
+    mutex_lock l(remote_tensor_handle_mu_);
+    auto iter = remote_tensor_handle_map_.find(remote_handle);
+    if (iter != remote_tensor_handle_map_.end()) {
+      iter->second->Unref();
+      remote_tensor_handle_map_.erase(iter);
+      return Status::OK();
+    }
   }
-
-  iter->second->Unref();
-  remote_tensor_handle_map_.erase(iter);
-
-  return Status::OK();
+  {
+    mutex_lock l(mirrored_resource_shape_mu_);
+    auto iter = mirrored_resource_shape_map_.find(remote_handle);
+    if (iter != mirrored_resource_shape_map_.end()) {
+      mirrored_resource_shape_map_.erase(iter);
+      return Status::OK();
+    }
+  }
+  return errors::InvalidArgument(
+      "Unable to find the relevant tensor remote_handle: Op ID: ",
+      remote_handle.op_id, ", Output num: ", remote_handle.output_num);
 }
 
 Status RemoteMgr::SerializeRemoteTensorHandle(
@@ -94,15 +118,8 @@ Status RemoteMgr::SerializeRemoteTensorHandle(
   int64 op_id;
   int32 output_num;
   if (!in->RemoteAddress(device, &op_id, &output_num).ok()) {
-    mutex_lock l(remote_tensor_handle_mu_);
-    if (!GetRemoteTensorHandle(in, &op_id, &output_num).ok()) {
-      op_id = NextOpId();
-      output_num = 0;
-      in->SetRemoteOpIdAndOutputNumToLocalTensorHandle(op_id, output_num);
-      in->Ref();
-      remote_tensor_handle_map_.emplace(
-          RemoteTensorHandleInternal(op_id, output_num), in);
-    }
+    tf_shared_lock l(remote_tensor_handle_mu_);
+    TF_RETURN_IF_ERROR(GetRemoteTensorHandle(in, &op_id, &output_num));
   }
   out->Clear();
   out->set_op_id(op_id);
@@ -150,10 +167,19 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
     TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
         std::move(remote_handle_data), in.dtype(), device, parent_, out));
     std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
-    for (const auto& dtype_and_shape_proto : in.resource_dtypes_and_shapes()) {
-      dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
-          dtype_and_shape_proto.dtype(),
-          TensorShape(dtype_and_shape_proto.shape())});
+    if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
+                                  &dtypes_and_shapes)
+             .ok()) {
+      for (const auto& dtype_and_shape_proto :
+           in.resource_dtypes_and_shapes()) {
+        dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
+            dtype_and_shape_proto.dtype(),
+            TensorShape(dtype_and_shape_proto.shape())});
+      }
+      mutex_lock l(mirrored_resource_shape_mu_);
+      mirrored_resource_shape_map_.emplace(
+          RemoteTensorHandleInternal(in.op_id(), in.output_num()),
+          dtypes_and_shapes);
    }
     (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
   }
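`DeserializeRemoteTensorHandle` now prefers a locally cached copy of a resource's dtypes and shapes and only parses the serialized protos on a cache miss, populating the cache so later requests can omit the shape payload. A get-or-insert sketch of that flow, with hypothetical `Key` and `Shape` types standing in for `RemoteTensorHandleInternal` and `DtypeAndPartialTensorShape`:

```cpp
// Cache-on-miss sketch: `parsed` plays the role of the shapes decoded from
// resource_dtypes_and_shapes() on the wire.
#include <cstdint>
#include <map>
#include <mutex>
#include <tuple>
#include <vector>

struct Key {
  int64_t op_id;
  int32_t output_num;
  bool operator<(const Key& o) const {
    return std::tie(op_id, output_num) < std::tie(o.op_id, o.output_num);
  }
};

struct Shape {
  int dtype;
  std::vector<int64_t> dims;
};

class ShapeCache {
 public:
  // Returns the cached entry when present; otherwise stores `parsed` and
  // returns it, so the expensive decode happens at most once per key.
  std::vector<Shape> GetOrInsert(const Key& k, std::vector<Shape> parsed) {
    std::lock_guard<std::mutex> l(mu_);
    auto it = cache_.find(k);
    if (it != cache_.end()) return it->second;
    return cache_.emplace(k, std::move(parsed)).first->second;
  }

 private:
  std::mutex mu_;
  std::map<Key, std::vector<Shape>> cache_;
};
```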
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.h b/tensorflow/core/distributed_runtime/eager/remote_mgr.h
index 6853fbb40cd..4fd3b09fbbb 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.h
@@ -58,7 +58,7 @@ class RemoteMgr {
     return next_op_id_++;
   }
 
-  // Serialize a TensorHandle(local/remote) to a RemoteTensorHandle.
+  // Serialize a remote TensorHandle to a RemoteTensorHandle.
   Status SerializeRemoteTensorHandle(
       TensorHandle* in, RemoteTensorHandle* out, Device* device,
       const string& device_name,
@@ -88,12 +88,20 @@ class RemoteMgr {
                              tensorflow::TensorHandle** handle)
       SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
 
+  Status GetMirroredResourceShape(
+      const RemoteTensorHandleInternal& remote_handle,
+      std::vector<DtypeAndPartialTensorShape>* handle);
+
   bool is_master_;
 
   using RemoteTensorHandleMap =
      gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*,
                   RemoteTensorHandleInternalHash,
                   RemoteTensorHandleInternalEquals>;
+  using MirroredResourceShapeMap = gtl::FlatMap<
+      RemoteTensorHandleInternal, std::vector<DtypeAndPartialTensorShape>,
+      RemoteTensorHandleInternalHash, RemoteTensorHandleInternalEquals>;
+
   mutex remote_tensor_handle_mu_;
   // This map maintains the TensorHandles that are required by remote workers
   // in the cluster. Each map key is generated by the master, so it should be
@@ -101,6 +109,13 @@ class RemoteMgr {
   RemoteTensorHandleMap remote_tensor_handle_map_
       GUARDED_BY(remote_tensor_handle_mu_);
 
+  mutex mirrored_resource_shape_mu_;
+  // This map maintains the data types and shapes of resource variables
+  // required by remote workers in the cluster. Each map key is generated by
+  // the master, so it should be globally unique.
+  MirroredResourceShapeMap mirrored_resource_shape_map_
+      GUARDED_BY(mirrored_resource_shape_mu_);
+
   EagerContext* parent_;  // not owned.
 
   mutex executor_map_mu_;
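Both maps in `RemoteMgr` are addressed by the `(op_id, output_num)` pair and share `RemoteTensorHandleInternalHash` and `RemoteTensorHandleInternalEquals`. For illustration, an equivalent declaration over `std::unordered_map` with hand-rolled functors (`HandleKey` and friends are stand-in names, not the TensorFlow types):

```cpp
// Stand-in equivalent of the gtl::FlatMap declarations above, keyed on the
// (op_id, output_num) pair with explicit hash and equality functors.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

struct HandleKey {
  int64_t op_id;
  int32_t output_num;
};

struct HandleKeyHash {
  std::size_t operator()(const HandleKey& k) const {
    // Fold both fields into one hash; op ids are assigned sequentially, so
    // mixing in output_num keeps multi-output ops from colliding.
    return std::hash<int64_t>()(k.op_id) ^
           (std::hash<int32_t>()(k.output_num) << 1);
  }
};

struct HandleKeyEq {
  bool operator()(const HandleKey& a, const HandleKey& b) const {
    return a.op_id == b.op_id && a.output_num == b.output_num;
  }
};

struct DtypeAndShape {
  int dtype;
  std::vector<int64_t> dims;
};

using MirroredShapeMap =
    std::unordered_map<HandleKey, std::vector<DtypeAndShape>, HandleKeyHash,
                       HandleKeyEq>;
```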
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
index 6bb4943ffee..312637b9965 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
@@ -68,41 +68,6 @@ class RemoteMgrTest : public ::testing::Test {
   EagerContext* ctx_;
 };
 
-TEST_F(RemoteMgrTest, LocalTensorHandle) {
-  TestRemoteMgr remote_mgr(true, ctx_);
-  Tensor t(DT_FLOAT, TensorShape({0}));
-
-  TensorHandle* handle;
-  TF_ASSERT_OK(TensorHandle::CreateLocalHandle(t, &handle));
-  EXPECT_EQ(nullptr, handle->device());
-  EXPECT_EQ(local_device_, handle->DeviceOrHostCPU(ctx_));
-  const uint64 op_id = remote_mgr.OpId();
-  EXPECT_EQ(1, op_id);
-  RemoteTensorHandle remote_handle;
-  TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
-      handle, &remote_handle, handle->device(),
-      handle->DeviceOrHostCPU(ctx_)->name()));
-  EXPECT_EQ(2, remote_mgr.OpId());
-  EXPECT_EQ(op_id, remote_handle.op_id());
-  EXPECT_EQ(0, remote_handle.output_num());
-  EXPECT_EQ(local_device_->name(), remote_handle.device());
-
-  TensorHandle* deserialized_handle;
-  TF_ASSERT_OK(remote_mgr.DeserializeRemoteTensorHandle(remote_handle,
-                                                        &deserialized_handle));
-  tensorflow::TensorHandle* h;
-  TF_EXPECT_OK(remote_mgr.GetTensorHandle(
-      RemoteTensorHandleInternal(remote_handle), &h));
-  TF_ASSERT_OK(
-      remote_mgr.DeleteTensorHandle(RemoteTensorHandleInternal(remote_handle)));
-  EXPECT_FALSE(
-      remote_mgr.GetTensorHandle(RemoteTensorHandleInternal(remote_handle), &h)
-          .ok());
-
-  deserialized_handle->Unref();
-  handle->Unref();
-}
-
 TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
   RemoteMgr remote_mgr(false, ctx_);
   Tensor t(DT_FLOAT, TensorShape({0}));
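The deleted `LocalTensorHandle` test exercised the implicit-registration path that `SerializeRemoteTensorHandle` no longer has: serialization now requires the handle to already be known remotely (see the `TF_RETURN_IF_ERROR(GetRemoteTensorHandle(...))` change above). A hypothetical replacement test, not part of this change, asserting the new contract:

```cpp
// Hypothetical test (not in this PR): a purely local handle that was never
// registered should now fail to serialize instead of being assigned an op id.
TEST_F(RemoteMgrTest, SerializeUnregisteredLocalTensorHandleFails) {
  RemoteMgr remote_mgr(false, ctx_);
  Tensor t(DT_FLOAT, TensorShape({0}));

  TensorHandle* handle;
  TF_ASSERT_OK(TensorHandle::CreateLocalHandle(t, &handle));
  RemoteTensorHandle remote_handle;
  EXPECT_FALSE(remote_mgr
                   .SerializeRemoteTensorHandle(
                       handle, &remote_handle, handle->device(),
                       handle->DeviceOrHostCPU(ctx_)->name())
                   .ok());
  handle->Unref();
}
```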