Make SerializeRemoteTensorHandle block only when the remote op is a function, so that primitive ops still benefit from asynchronous execution.

PiperOrigin-RevId: 311423473
Change-Id: I87a3973ddf1954facb69c14499ce2fa07a9d6e99
Author: Yujing Zhang, 2020-05-13 16:07:02 -07:00; committed by TensorFlower Gardener
parent e1b0e64119
commit 0ac3572e8d
12 changed files with 88 additions and 35 deletions
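For orientation: the change threads a new wait_until_ready flag from the execute path (where it is set from op->is_function()) down to RemoteTensorHandleData::OpIdAndOutputNum, which now blocks only when asked to. Below is a minimal, self-contained C++ sketch of that gating logic; RemoteHandleData and its members are illustrative stand-ins, not the real TensorFlow eager classes.

// Toy model of the new behavior; RemoteHandleData stands in for
// RemoteTensorHandleData and is not the real TensorFlow class.
#include <condition_variable>
#include <cstdint>
#include <mutex>

struct RemoteHandleData {
  int64_t op_id = 1;
  int32_t output_num = 0;
  bool ready = false;
  std::mutex mu;
  std::condition_variable cv;

  void SetReady() {
    { std::lock_guard<std::mutex> l(mu); ready = true; }
    cv.notify_all();
  }

  // Mirrors OpIdAndOutputNum: the address is returned immediately unless the
  // caller asks to wait for the producing op to finish on the remote worker.
  void OpIdAndOutputNum(bool wait_until_ready, int64_t* id, int32_t* num) {
    std::unique_lock<std::mutex> l(mu);
    if (wait_until_ready) cv.wait(l, [this] { return ready; });
    *id = op_id;
    *num = output_num;
  }
};

int main() {
  RemoteHandleData data;
  int64_t id; int32_t num;
  // Primitive ops keep the asynchronous fast path and never block here.
  data.OpIdAndOutputNum(/*wait_until_ready=*/false, &id, &num);
  // Functions pass true; once the handle is marked ready this returns as well.
  data.SetReady();
  data.OpIdAndOutputNum(/*wait_until_ready=*/true, &id, &num);
  return 0;
}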

@@ -434,6 +434,22 @@ string AddVariablesFunction() {
return def.SerializeAsString();
}
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
delete status;
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
@@ -474,6 +490,12 @@ void TestFunctionWithPackedInput(const bool remote) {
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
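For context, the packing mentioned in the comment above is done with TFE_CreatePackedTensorHandle from the experimental eager C API; the lines that actually follow in the test fall outside this hunk, so the call below is only a hedged illustration of that step, with pack_status and packed_handle as illustrative names.

// Illustrative only; the real test's follow-up lines are not shown in this
// diff. TFE_CreatePackedTensorHandle packs the replica handles into one.
TF_Status* pack_status = TF_NewStatus();
TFE_TensorHandle* packed_handle =
    TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, pack_status);
EXPECT_EQ(TF_GetCode(pack_status), TF_OK) << TF_Message(pack_status);
TF_DeleteStatus(pack_status);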

@@ -782,9 +782,15 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
}
}
auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle();
// For a multi-device function, a remote RunComponentFunction request is
// not sent through StreamingEnqueueAsync. It could arrive at a remote
// worker before a remote execution request which produces an input of the
// component function. So we wait until the remote input is ready before
// serializing it.
const bool wait_until_ready = op->is_function();
TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
input, input_handle, input_device, *input_device_name,
serialize_resource_dtype_and_shape));
input, wait_until_ready, input_handle, input_device,
*input_device_name, serialize_resource_dtype_and_shape));
if (!input_handle->resource_dtypes_and_shapes().empty()) {
TF_RETURN_IF_ERROR(
input->AddResourceShapeMirror(op_device, input_handle->op_id(),
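The comment in this hunk describes an ordering hazard: requests sent through StreamingEnqueueAsync execute in order on the remote worker, while a RunComponentFunction request takes a separate path and can overtake them. The following self-contained toy model (not TensorFlow code; ToyWorker and its members are made up for illustration) shows why a component function's input must be waited on before it is serialized.

// Toy model of the race: an in-order streaming queue vs. an out-of-band
// component-function request. Not TensorFlow code.
#include <cstdio>
#include <deque>
#include <string>
#include <unordered_set>

struct ToyWorker {
  std::unordered_set<std::string> produced;   // tensors materialized so far
  std::deque<std::string> streaming_queue;    // enqueued ops, executed in order

  void DrainStreamingQueue() {
    while (!streaming_queue.empty()) {
      produced.insert(streaming_queue.front());
      streaming_queue.pop_front();
    }
  }

  // A component-function request arriving out of band only succeeds if its
  // input has already been produced.
  bool RunComponentFunction(const std::string& input) const {
    return produced.count(input) > 0;
  }
};

int main() {
  ToyWorker w;
  w.streaming_queue.push_back("producer_output");  // producing op still queued
  // Sent without waiting for the input to be ready, the function request wins
  // the race and does not find its input on the worker.
  std::printf("before drain: %d\n", w.RunComponentFunction("producer_output"));
  w.DrainStreamingQueue();
  std::printf("after drain:  %d\n", w.RunComponentFunction("producer_output"));
  return 0;
}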

@@ -97,9 +97,11 @@ Status ExecuteNodeArgs::Init(
#if !defined(IS_MOBILE_PLATFORM)
if (has_remote_inputs_) {
const bool is_function = kernel->IsFunction();
serialize_remote_handle_ =
[ctx, &op_inputs](const FunctionArgIndex& index,
eager::RemoteTensorHandle* handle) -> Status {
[ctx, &op_inputs, is_function](
const FunctionArgIndex& index,
eager::RemoteTensorHandle* handle) -> Status {
TensorHandle* h = op_inputs[index.index];
if (op_inputs[index.index]->Type() == TensorHandle::PACKED) {
TF_RETURN_IF_ERROR(
@@ -112,8 +114,14 @@ Status ExecuteNodeArgs::Init(
"together.");
}
Device* device = absl::get<Device*>(variant_device);
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(h, handle, device,
device->name());
// For a multi-device function, a remote RunComponentFunction request is
// not sent through StreamingEnqueueAsync. It could arrive at a remote
// worker before a remote execution request which produces an input of the
// component function. So we wait until the remote input is ready before
// serializing it.
const bool wait_until_ready = is_function;
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
h, wait_until_ready, handle, device, device->name());
};
}
#endif // !IS_MOBILE_PLATFORM

@@ -705,8 +705,8 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
}
#if !defined(IS_MOBILE_PLATFORM)
Status TensorHandle::RemoteAddressUntilReady(const Device* d, int64* op_id,
int32* output_num) const {
Status TensorHandle::RemoteAddress(const Device* d, const bool wait_until_ready,
int64* op_id, int32* output_num) const {
DVLOG(3) << "RemoteAddress on TensorHandle: " << this << " device: " << d
<< " " << d->name();
@@ -714,7 +714,8 @@ Status TensorHandle::RemoteAddressUntilReady(const Device* d, int64* op_id,
tf_shared_lock l(mu_);
auto mirror = remote_mirrors_.find(d->name());
if (mirror != remote_mirrors_.end()) {
return mirror->second.OpIdAndOutputNumUntilReady(op_id, output_num);
return mirror->second.OpIdAndOutputNum(wait_until_ready, op_id,
output_num);
}
return errors::FailedPrecondition(
@@ -726,7 +727,7 @@ Status TensorHandle::RemoteAddressUntilReady(const Device* d, int64* op_id,
}
auto& data = absl::get<RemoteTensorHandleData>(data_);
return data.OpIdAndOutputNumUntilReady(op_id, output_num);
return data.OpIdAndOutputNum(wait_until_ready, op_id, output_num);
}
bool TensorHandle::HasRemoteMirror(const Device* d,

@@ -168,10 +168,11 @@ class TensorHandle : public AbstractTensorHandleInterface,
Status AddResourceShapeMirror(const Device* d, int64 op_id, int output_num,
EagerContext* ctx);
// Return the op_id and output num if the handle refers to a remote tensor;
// and blocks until the remote tensor is ready on the given remote worker.
Status RemoteAddressUntilReady(const Device* d, int64* op_id,
int32* output_num) const;
// Return the op_id and output num if the handle refers to a remote tensor.
// If wait_until_ready is true, block until the remote tensor is ready on the
// given remote worker.
Status RemoteAddress(const Device* d, const bool wait_until_ready,
int64* op_id, int32* output_num) const;
// Called on an async remote tensor once its shape has been determined. This
// transitions the tensor handle from a non-ready to a ready state by
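A hedged usage sketch of the new RemoteAddress signature follows. It is a fragment rather than a standalone program; LookupRemoteAddress is a hypothetical helper, and h and d are assumed to be a remote TensorHandle* and its owning Device* obtained from an existing eager context, mirroring the call sites updated elsewhere in this commit.

// Sketch only: TensorHandle, Device, Status, int64 and int32 are the
// TensorFlow eager/platform types; LookupRemoteAddress is hypothetical.
Status LookupRemoteAddress(TensorHandle* h, const Device* d, bool is_function,
                           int64* op_id, int32* output_num) {
  // Only function (component-function) inputs need to block until the remote
  // tensor is ready; primitive ops keep the asynchronous fast path.
  const bool wait_until_ready = is_function;
  return h->RemoteAddress(d, wait_until_ready, op_id, output_num);
}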

@@ -970,8 +970,9 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
EXPECT_EQ(handle2->op_device()->name(), device2);
int64 op_id;
int32 output_num;
TF_ASSERT_OK(handle2->RemoteAddressUntilReady(
absl::get<Device*>(handle2->device()), &op_id, &output_num));
TF_ASSERT_OK(handle2->RemoteAddress(absl::get<Device*>(handle2->device()),
/*wait_until_ready=*/true, &op_id,
&output_num));
EXPECT_EQ(op_id, 2);
EXPECT_EQ(output_num, 5);

@@ -147,7 +147,8 @@ void RemoteCopyNode::StartSend() {
request.set_context_id(ctx_->GetContextId());
auto* remote_op = request.add_queue()->mutable_operation();
status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
src_, remote_op->add_op_inputs()->mutable_remote_handle(),
src_, /*wait_until_ready=*/false,
remote_op->add_op_inputs()->mutable_remote_handle(),
absl::get<Device*>(src_->device()),
absl::get<Device*>(src_->DeviceOrHostCPU(*ctx_))->name());
if (!status.ok()) {
@@ -316,7 +317,8 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
(i == 0) && (h->dtype == DT_RESOURCE) &&
(ctx->OnSameTask(src_device, target_device));
TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
h, op->add_handles()->mutable_remote_handle(), src_device,
h, /*wait_until_ready=*/false,
op->add_handles()->mutable_remote_handle(), src_device,
absl::get<Device*>(h->DeviceOrHostCPU(*ctx))->name(),
serialize_resource_dtype_and_shape));
} else {

@@ -74,6 +74,7 @@ Status RemoteMgr::GetMirroredResourceShape(
}
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
const bool wait_until_ready,
int64* op_id, int32* output_num) {
// TODO(allenl): Consider supporting remote handles on custom devices.
VariantDevice device = handle->device();
@@ -82,8 +83,8 @@ Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
"Custom devices and remote execution are currently not supported "
"together.");
}
TF_RETURN_IF_ERROR(handle->RemoteAddressUntilReady(absl::get<Device*>(device),
op_id, output_num));
TF_RETURN_IF_ERROR(handle->RemoteAddress(
absl::get<Device*>(device), wait_until_ready, op_id, output_num));
tensorflow::TensorHandle* h;
TF_RETURN_IF_ERROR(
GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
@@ -120,13 +121,15 @@ Status RemoteMgr::DeleteTensorHandle(
}
Status RemoteMgr::SerializeRemoteTensorHandle(
TensorHandle* in, RemoteTensorHandle* out, Device* device,
const string& device_name, const bool serialize_resource_dtype_and_shape) {
TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out,
Device* device, const string& device_name,
const bool serialize_resource_dtype_and_shape) {
int64 op_id;
int32 output_num;
if (!in->RemoteAddressUntilReady(device, &op_id, &output_num).ok()) {
if (!in->RemoteAddress(device, wait_until_ready, &op_id, &output_num).ok()) {
tf_shared_lock l(remote_tensor_handle_mu_);
TF_RETURN_IF_ERROR(GetRemoteTensorHandle(in, &op_id, &output_num));
TF_RETURN_IF_ERROR(
GetRemoteTensorHandle(in, wait_until_ready, &op_id, &output_num));
}
out->Clear();
out->set_op_id(op_id);

@@ -61,9 +61,11 @@ class RemoteMgr {
}
// Serialize a remote TensorHandle to a RemoteTensorHandle.
// If wait_until_ready is true, block until the remote handle is ready on a
// remote worker.
Status SerializeRemoteTensorHandle(
TensorHandle* in, RemoteTensorHandle* out, Device* device,
const string& device_name,
TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out,
Device* device, const string& device_name,
const bool serialize_resource_dtype_and_shape = false);
// Deserialize a RemoteTensorHandle to a TensorHandle(local/remote).
@@ -83,7 +85,8 @@ class RemoteMgr {
// Returns the op_id and output_num if the given local TensorHandle exists in
// remote_tensor_handle_map_.
Status GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
int64* op_id, int32* output_num)
const bool wait_until_ready, int64* op_id,
int32* output_num)
TF_SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
Status GetTensorHandleImpl(const RemoteTensorHandleInternal& remote_handle,

@@ -81,7 +81,8 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
handle->SetRemoteShape(shape, remote_device_, ctx_->GetContextViewId()));
RemoteTensorHandle remote_handle;
TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
handle, &remote_handle, remote_device_, remote_device_->name()));
handle, /*wait_until_ready=*/true, &remote_handle, remote_device_,
remote_device_->name()));
EXPECT_EQ(op_id, remote_handle.op_id());
EXPECT_EQ(output_num, remote_handle.output_num());
EXPECT_EQ(remote_device_->name(), remote_handle.device());
@@ -97,7 +98,8 @@ TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
op_id, output_num, DT_FLOAT, remote_device_, ctx_);
RemoteTensorHandle remote_handle;
TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
handle, &remote_handle, remote_device_, remote_device_->name()));
handle, /*wait_until_ready=*/true, &remote_handle, remote_device_,
remote_device_->name()));
EXPECT_EQ(op_id, remote_handle.op_id());
EXPECT_EQ(output_num, remote_handle.output_num());
EXPECT_EQ(remote_device_->name(), remote_handle.device());

@@ -194,9 +194,12 @@ string RemoteTensorHandleData::DebugString() const {
" output_num: ", output_num_);
}
Status RemoteTensorHandleData::OpIdAndOutputNumUntilReady(
int64* op_id, int32* output_num) const {
TF_RETURN_IF_ERROR(WaitReady("OpIdAndOutputNumUntilReady"));
Status RemoteTensorHandleData::OpIdAndOutputNum(const bool wait_until_ready,
int64* op_id,
int32* output_num) const {
if (wait_until_ready) {
TF_RETURN_IF_ERROR(WaitReady("OpIdAndOutputNumUntilReady"));
}
*op_id = op_id_;
*output_num = output_num_;
return Status::OK();

@@ -50,9 +50,10 @@ class RemoteTensorHandleData {
string DebugString() const;
// Block until the remote tensor is ready on a remote worker and return the op
// id and output num.
Status OpIdAndOutputNumUntilReady(int64* op_id, int32* output_num) const;
// Return the op id and output num. If wait_until_ready is true, block until
// the remote tensor is ready on a remote worker.
Status OpIdAndOutputNum(const bool wait_until_ready, int64* op_id,
int32* output_num) const;
uint64 context_view_id() const { return context_view_id_; }