For a multi-host function, cache dtypes and shapes of remote variable inputs on the default function device, when lazy tensor copy is enabled.
With this change, the dtypes and shapes would only be serialized and sent to the default function device once. PiperOrigin-RevId: 280758012 Change-Id: I1560a9d171f627b0d20aae51dd5f35a3b4f2c437
This commit is contained in:
parent
fa0bfeb53e
commit
c08edd7f92
@ -780,12 +780,23 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
handle->Unref();
|
||||
}
|
||||
} else {
|
||||
serialize_resource_dtype_and_shape = input->dtype == DT_RESOURCE;
|
||||
serialize_resource_dtype_and_shape =
|
||||
(input->dtype == DT_RESOURCE) &&
|
||||
(!input->HasResourceShapeMirror(op->Device()));
|
||||
}
|
||||
}
|
||||
auto* input_handle = remote_op->add_inputs();
|
||||
TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
|
||||
input, remote_op->add_inputs(), input_device, *input_device_name,
|
||||
input, input_handle, input_device, *input_device_name,
|
||||
serialize_resource_dtype_and_shape));
|
||||
if (!input_handle->resource_dtypes_and_shapes().empty()) {
|
||||
auto tensor_handle_data =
|
||||
absl::make_unique<UnshapedRemoteTensorHandleData>(
|
||||
input_handle->op_id(), input_handle->output_num(), remote_task,
|
||||
context_id, ctx);
|
||||
TF_RETURN_IF_ERROR(input->AddResourceShapeMirror(
|
||||
std::move(tensor_handle_data), op->Device()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -470,6 +470,15 @@ bool TensorHandle::HasRemoteMirror(Device* d) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool TensorHandle::HasResourceShapeMirror(Device* d) {
|
||||
tf_shared_lock l(resource_shape_mirrors_mutex_);
|
||||
auto mirror = resource_shape_mirrors_.find(d);
|
||||
if (mirror != resource_shape_mirrors_.end()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status TensorHandle::AddUnshapedRemoteMirror(
|
||||
std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
|
||||
mutex_lock l(remote_mirrors_mutex_);
|
||||
@ -486,6 +495,17 @@ Status TensorHandle::AddUnshapedRemoteMirror(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::AddResourceShapeMirror(
|
||||
std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
|
||||
mutex_lock l(resource_shape_mirrors_mutex_);
|
||||
auto ret = resource_shape_mirrors_.insert(std::make_pair(d, std::move(t)));
|
||||
if (!ret.second) {
|
||||
return errors::Internal("Attempted to duplicate a resource shape mirror.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
|
||||
Device* d) {
|
||||
mutex_lock l(remote_mirrors_mutex_);
|
||||
|
@ -142,10 +142,13 @@ class TensorHandle : public core::RefCounted {
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool HasRemoteMirror(Device* d);
|
||||
bool HasResourceShapeMirror(Device* d);
|
||||
|
||||
Status AddUnshapedRemoteMirror(
|
||||
std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);
|
||||
Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t, Device* d);
|
||||
Status AddResourceShapeMirror(
|
||||
std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);
|
||||
|
||||
// Return the op_id and output num if the handle refers to a remote tensor.
|
||||
Status RemoteAddress(Device* d, int64* op_id, int32* output_num) const;
|
||||
@ -245,6 +248,13 @@ class TensorHandle : public core::RefCounted {
|
||||
tensorflow::Device* const resource_device_;
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
// TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
|
||||
// variable is ready, since we could get the shape locally without remote copy
|
||||
// then.
|
||||
mutable mutex resource_shape_mirrors_mutex_;
|
||||
std::map<tensorflow::Device*, std::unique_ptr<UnshapedRemoteTensorHandleData>>
|
||||
resource_shape_mirrors_ GUARDED_BY(resource_shape_mirrors_mutex_);
|
||||
|
||||
mutable mutex remote_mirrors_mutex_;
|
||||
// TODO(gjn): Unshaped remote mirrors are long expected to be long-lived.
|
||||
// Consider replacing the unshaped_remote_mirrors_ map with something more
|
||||
|
@ -57,6 +57,22 @@ Status RemoteMgr::GetTensorHandle(
|
||||
return GetTensorHandleImpl(remote_handle, handle);
|
||||
}
|
||||
|
||||
Status RemoteMgr::GetMirroredResourceShape(
|
||||
const RemoteTensorHandleInternal& remote_handle,
|
||||
std::vector<DtypeAndPartialTensorShape>* handle) {
|
||||
tf_shared_lock l(mirrored_resource_shape_mu_);
|
||||
auto iter = mirrored_resource_shape_map_.find(remote_handle);
|
||||
if (iter == mirrored_resource_shape_map_.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"Unable to find the relevant mirrored resource shape: Op ID: ",
|
||||
remote_handle.op_id, ", Output num: ", remote_handle.output_num);
|
||||
}
|
||||
|
||||
*handle = iter->second;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
|
||||
int64* op_id, int32* output_num) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -74,18 +90,26 @@ Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
|
||||
|
||||
Status RemoteMgr::DeleteTensorHandle(
|
||||
const RemoteTensorHandleInternal& remote_handle) {
|
||||
mutex_lock l(remote_tensor_handle_mu_);
|
||||
auto iter = remote_tensor_handle_map_.find(remote_handle);
|
||||
if (iter == remote_tensor_handle_map_.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"Unable to find the relevant tensor remote_handle: Op ID: ",
|
||||
remote_handle.op_id, ", Output num: ", remote_handle.output_num);
|
||||
{
|
||||
mutex_lock l(remote_tensor_handle_mu_);
|
||||
auto iter = remote_tensor_handle_map_.find(remote_handle);
|
||||
if (iter != remote_tensor_handle_map_.end()) {
|
||||
iter->second->Unref();
|
||||
remote_tensor_handle_map_.erase(iter);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
iter->second->Unref();
|
||||
remote_tensor_handle_map_.erase(iter);
|
||||
|
||||
return Status::OK();
|
||||
{
|
||||
mutex_lock l(mirrored_resource_shape_mu_);
|
||||
auto iter = mirrored_resource_shape_map_.find(remote_handle);
|
||||
if (iter != mirrored_resource_shape_map_.end()) {
|
||||
mirrored_resource_shape_map_.erase(iter);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
return errors::InvalidArgument(
|
||||
"Unable to find the relevant tensor remote_handle: Op ID: ",
|
||||
remote_handle.op_id, ", Output num: ", remote_handle.output_num);
|
||||
}
|
||||
|
||||
Status RemoteMgr::SerializeRemoteTensorHandle(
|
||||
@ -94,15 +118,8 @@ Status RemoteMgr::SerializeRemoteTensorHandle(
|
||||
int64 op_id;
|
||||
int32 output_num;
|
||||
if (!in->RemoteAddress(device, &op_id, &output_num).ok()) {
|
||||
mutex_lock l(remote_tensor_handle_mu_);
|
||||
if (!GetRemoteTensorHandle(in, &op_id, &output_num).ok()) {
|
||||
op_id = NextOpId();
|
||||
output_num = 0;
|
||||
in->SetRemoteOpIdAndOutputNumToLocalTensorHandle(op_id, output_num);
|
||||
in->Ref();
|
||||
remote_tensor_handle_map_.emplace(
|
||||
RemoteTensorHandleInternal(op_id, output_num), in);
|
||||
}
|
||||
tf_shared_lock l(remote_tensor_handle_mu_);
|
||||
TF_RETURN_IF_ERROR(GetRemoteTensorHandle(in, &op_id, &output_num));
|
||||
}
|
||||
out->Clear();
|
||||
out->set_op_id(op_id);
|
||||
@ -150,10 +167,19 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
|
||||
std::move(remote_handle_data), in.dtype(), device, parent_, out));
|
||||
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
|
||||
for (const auto& dtype_and_shape_proto : in.resource_dtypes_and_shapes()) {
|
||||
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
|
||||
dtype_and_shape_proto.dtype(),
|
||||
TensorShape(dtype_and_shape_proto.shape())});
|
||||
if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
|
||||
&dtypes_and_shapes)
|
||||
.ok()) {
|
||||
for (const auto& dtype_and_shape_proto :
|
||||
in.resource_dtypes_and_shapes()) {
|
||||
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
|
||||
dtype_and_shape_proto.dtype(),
|
||||
TensorShape(dtype_and_shape_proto.shape())});
|
||||
}
|
||||
mutex_lock l(mirrored_resource_shape_mu_);
|
||||
mirrored_resource_shape_map_.emplace(
|
||||
RemoteTensorHandleInternal(in.op_id(), in.output_num()),
|
||||
dtypes_and_shapes);
|
||||
}
|
||||
(*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
|
||||
}
|
||||
|
@ -58,7 +58,7 @@ class RemoteMgr {
|
||||
return next_op_id_++;
|
||||
}
|
||||
|
||||
// Serialize a TensorHandle(local/remote) to a RemoteTensorHandle.
|
||||
// Serialize a remote TensorHandle to a RemoteTensorHandle.
|
||||
Status SerializeRemoteTensorHandle(
|
||||
TensorHandle* in, RemoteTensorHandle* out, Device* device,
|
||||
const string& device_name,
|
||||
@ -88,12 +88,20 @@ class RemoteMgr {
|
||||
tensorflow::TensorHandle** handle)
|
||||
SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
|
||||
|
||||
Status GetMirroredResourceShape(
|
||||
const RemoteTensorHandleInternal& remote_handle,
|
||||
std::vector<DtypeAndPartialTensorShape>* handle);
|
||||
|
||||
bool is_master_;
|
||||
|
||||
using RemoteTensorHandleMap =
|
||||
gtl::FlatMap<RemoteTensorHandleInternal, tensorflow::TensorHandle*,
|
||||
RemoteTensorHandleInternalHash,
|
||||
RemoteTensorHandleInternalEquals>;
|
||||
using MirroredResourceShapeMap = gtl::FlatMap<
|
||||
RemoteTensorHandleInternal, std::vector<DtypeAndPartialTensorShape>,
|
||||
RemoteTensorHandleInternalHash, RemoteTensorHandleInternalEquals>;
|
||||
|
||||
mutex remote_tensor_handle_mu_;
|
||||
// This map maintains the TensorHandles that are required by remote workers
|
||||
// in the cluster. Each map key is generated by the master, so it should be
|
||||
@ -101,6 +109,13 @@ class RemoteMgr {
|
||||
RemoteTensorHandleMap remote_tensor_handle_map_
|
||||
GUARDED_BY(remote_tensor_handle_mu_);
|
||||
|
||||
mutex mirrored_resource_shape_mu_;
|
||||
// This map maintains the data types and shapes of resource variables required
|
||||
// by remote workers in the cluster. Each map key is generated by the master,
|
||||
// so it should be globally unique.
|
||||
MirroredResourceShapeMap mirrored_resource_shape_map_
|
||||
GUARDED_BY(mirrored_resource_shape_mu_);
|
||||
|
||||
EagerContext* parent_; // not owned.
|
||||
|
||||
mutex executor_map_mu_;
|
||||
|
@ -68,41 +68,6 @@ class RemoteMgrTest : public ::testing::Test {
|
||||
EagerContext* ctx_;
|
||||
};
|
||||
|
||||
TEST_F(RemoteMgrTest, LocalTensorHandle) {
|
||||
TestRemoteMgr remote_mgr(true, ctx_);
|
||||
Tensor t(DT_FLOAT, TensorShape({0}));
|
||||
|
||||
TensorHandle* handle;
|
||||
TF_ASSERT_OK(TensorHandle::CreateLocalHandle(t, &handle));
|
||||
EXPECT_EQ(nullptr, handle->device());
|
||||
EXPECT_EQ(local_device_, handle->DeviceOrHostCPU(ctx_));
|
||||
const uint64 op_id = remote_mgr.OpId();
|
||||
EXPECT_EQ(1, op_id);
|
||||
RemoteTensorHandle remote_handle;
|
||||
TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
|
||||
handle, &remote_handle, handle->device(),
|
||||
handle->DeviceOrHostCPU(ctx_)->name()));
|
||||
EXPECT_EQ(2, remote_mgr.OpId());
|
||||
EXPECT_EQ(op_id, remote_handle.op_id());
|
||||
EXPECT_EQ(0, remote_handle.output_num());
|
||||
EXPECT_EQ(local_device_->name(), remote_handle.device());
|
||||
|
||||
TensorHandle* deserialized_handle;
|
||||
TF_ASSERT_OK(remote_mgr.DeserializeRemoteTensorHandle(remote_handle,
|
||||
&deserialized_handle));
|
||||
tensorflow::TensorHandle* h;
|
||||
TF_EXPECT_OK(remote_mgr.GetTensorHandle(
|
||||
RemoteTensorHandleInternal(remote_handle), &h));
|
||||
TF_ASSERT_OK(
|
||||
remote_mgr.DeleteTensorHandle(RemoteTensorHandleInternal(remote_handle)));
|
||||
EXPECT_FALSE(
|
||||
remote_mgr.GetTensorHandle(RemoteTensorHandleInternal(remote_handle), &h)
|
||||
.ok());
|
||||
|
||||
deserialized_handle->Unref();
|
||||
handle->Unref();
|
||||
}
|
||||
|
||||
TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
|
||||
RemoteMgr remote_mgr(false, ctx_);
|
||||
Tensor t(DT_FLOAT, TensorShape({0}));
|
||||
|
Loading…
Reference in New Issue
Block a user