Make SerializeRemoteTensorHandle block only when the remote op is a function, so that non-function ops still benefit from async execution.
PiperOrigin-RevId: 311423473
Change-Id: I87a3973ddf1954facb69c14499ce2fa07a9d6e99
commit 0ac3572e8d (parent e1b0e64119)
Directories touched: tensorflow/c/eager, tensorflow/core
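The change threads a wait_until_ready flag from the call sites down to RemoteTensorHandleData, so serializing a remote input blocks on readiness only for function ops. Below is a minimal, standalone sketch of that control flow; the names SketchRemoteHandle and SerializeHandle are simplified stand-ins for illustration, not the real TensorFlow classes, and only the blocking behavior mirrors the diff that follows.

// A minimal, standalone sketch (not the real TensorFlow classes) of the
// behavior this commit introduces: the serializer blocks on handle readiness
// only when asked to, and callers ask only for function ops.
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <thread>

class SketchRemoteHandle {
 public:
  // Called once the remote worker has produced the tensor.
  void SetReady(int64_t op_id, int32_t output_num) {
    std::lock_guard<std::mutex> lock(mu_);
    op_id_ = op_id;
    output_num_ = output_num;
    ready_ = true;
    cv_.notify_all();
  }

  // Mirrors the shape of RemoteTensorHandleData::OpIdAndOutputNum in the diff:
  // block on readiness only if wait_until_ready is true, then report the address.
  void OpIdAndOutputNum(bool wait_until_ready, int64_t* op_id,
                        int32_t* output_num) {
    std::unique_lock<std::mutex> lock(mu_);
    if (wait_until_ready) {
      cv_.wait(lock, [this] { return ready_; });
    }
    *op_id = op_id_;
    *output_num = output_num_;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool ready_ = false;
  int64_t op_id_ = 0;
  int32_t output_num_ = 0;
};

// Stand-in for SerializeRemoteTensorHandle: the caller decides whether to block.
void SerializeHandle(SketchRemoteHandle& in, bool wait_until_ready) {
  int64_t op_id = 0;
  int32_t output_num = 0;
  in.OpIdAndOutputNum(wait_until_ready, &op_id, &output_num);
  std::cout << "serialized op_id=" << op_id << " output_num=" << output_num
            << " blocked=" << wait_until_ready << "\n";
}

int main() {
  SketchRemoteHandle handle;
  // Producer thread models the remote worker finishing the op that makes the
  // input ready.
  std::thread producer([&handle] {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    handle.SetReady(/*op_id=*/2, /*output_num=*/5);
  });
  const bool is_function = true;  // stand-in for op->is_function()
  // A function input waits for readiness; a non-function op would pass false
  // and serialize immediately.
  SerializeHandle(handle, /*wait_until_ready=*/is_function);
  producer.join();
  return 0;
}

Per the hunks below, eager remote execution passes op->is_function() (and ExecuteNodeArgs passes kernel->IsFunction()) as the flag, while RemoteCopyNode::StartSend and SerializePackedHandle pass /*wait_until_ready=*/false, so plain remote ops and tensor copies keep the fully asynchronous enqueue path.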
@@ -434,6 +434,22 @@ string AddVariablesFunction() {
   return def.SerializeAsString();
 }
 
+void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
+  TF_Status* status = TF_NewStatus();
+  TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
+  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_OpAddInput(op, var_handle, status);
+  TFE_TensorHandle* is_initialized[1] = {nullptr};
+  int num_retvals = 1;
+  TFE_Execute(op, &is_initialized[0], &num_retvals, status);
+  CHECK_EQ(1, num_retvals);
+  TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
+  bool initialized = false;
+  memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
+  EXPECT_EQ(initialized, true);
+  delete status;
+}
+
 void TestFunctionWithPackedInput(const bool remote) {
   tensorflow::ServerDef server_def = GetServerDef(3);
 
@@ -474,6 +490,12 @@ void TestFunctionWithPackedInput(const bool remote) {
   TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
   TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
 
+  // Add a sync point in order to make sure that variables have been
+  // initialized before the function execution starts.
+  // TODO(b/155789951): Remove once b/155789951 is fixed.
+  VarIsInitialized(ctx, h1);
+  VarIsInitialized(ctx, h2);
+
   // Pack 3 variable handles into one TFE_TensorHandle.
   int num_replicas = 3;
   std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
@@ -782,9 +782,15 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
       }
     }
     auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle();
+    // For a multi-device function, a remote RunComponentFunction request is
+    // not sent through StreamingEnqueueAsync. It could arrive at a remote
+    // worker before a remote execution request which produces an input of the
+    // component function. So we wait until the remote input is ready before
+    // serializing it.
+    const bool wait_until_ready = op->is_function();
     TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle(
-        input, input_handle, input_device, *input_device_name,
-        serialize_resource_dtype_and_shape));
+        input, wait_until_ready, input_handle, input_device,
+        *input_device_name, serialize_resource_dtype_and_shape));
     if (!input_handle->resource_dtypes_and_shapes().empty()) {
       TF_RETURN_IF_ERROR(
           input->AddResourceShapeMirror(op_device, input_handle->op_id(),
@@ -97,9 +97,11 @@ Status ExecuteNodeArgs::Init(
 
 #if !defined(IS_MOBILE_PLATFORM)
   if (has_remote_inputs_) {
+    const bool is_function = kernel->IsFunction();
     serialize_remote_handle_ =
-        [ctx, &op_inputs](const FunctionArgIndex& index,
-                          eager::RemoteTensorHandle* handle) -> Status {
+        [ctx, &op_inputs, is_function](
+            const FunctionArgIndex& index,
+            eager::RemoteTensorHandle* handle) -> Status {
       TensorHandle* h = op_inputs[index.index];
       if (op_inputs[index.index]->Type() == TensorHandle::PACKED) {
         TF_RETURN_IF_ERROR(
@@ -112,8 +114,14 @@ Status ExecuteNodeArgs::Init(
           "together.");
       }
       Device* device = absl::get<Device*>(variant_device);
-      return ctx->RemoteMgr()->SerializeRemoteTensorHandle(h, handle, device,
-                                                           device->name());
+      // For a multi-device function, a remote RunComponentFunction request is
+      // not sent through StreamingEnqueueAsync. It could arrive at a remote
+      // worker before a remote execution request which produces an input of the
+      // component function. So we wait until the remote input is ready before
+      // serializing it.
+      const bool wait_util_ready = is_function;
+      return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
+          h, wait_util_ready, handle, device, device->name());
     };
   }
 #endif  // !IS_MOBILE_PLATFORM
@@ -705,8 +705,8 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
 }
 
 #if !defined(IS_MOBILE_PLATFORM)
-Status TensorHandle::RemoteAddressUntilReady(const Device* d, int64* op_id,
-                                             int32* output_num) const {
+Status TensorHandle::RemoteAddress(const Device* d, const bool wait_until_ready,
+                                   int64* op_id, int32* output_num) const {
   DVLOG(3) << "RemoteAddress on TensorHandle: " << this << " device: " << d
            << " " << d->name();
 
@@ -714,7 +714,8 @@ Status TensorHandle::RemoteAddressUntilReady(const Device* d, int64* op_id,
     tf_shared_lock l(mu_);
     auto mirror = remote_mirrors_.find(d->name());
     if (mirror != remote_mirrors_.end()) {
-      return mirror->second.OpIdAndOutputNumUntilReady(op_id, output_num);
+      return mirror->second.OpIdAndOutputNum(wait_until_ready, op_id,
+                                             output_num);
     }
 
     return errors::FailedPrecondition(
@@ -726,7 +727,7 @@ Status TensorHandle::RemoteAddressUntilReady(const Device* d, int64* op_id,
   }
 
   auto& data = absl::get<RemoteTensorHandleData>(data_);
-  return data.OpIdAndOutputNumUntilReady(op_id, output_num);
+  return data.OpIdAndOutputNum(wait_until_ready, op_id, output_num);
 }
 
 bool TensorHandle::HasRemoteMirror(const Device* d,
@@ -168,10 +168,11 @@ class TensorHandle : public AbstractTensorHandleInterface,
   Status AddResourceShapeMirror(const Device* d, int64 op_id, int output_num,
                                 EagerContext* ctx);
 
-  // Return the op_id and output num if the handle refers to a remote tensor;
-  // and blocks until the remote tensor is ready on the given remote worker.
-  Status RemoteAddressUntilReady(const Device* d, int64* op_id,
-                                 int32* output_num) const;
+  // Return the op_id and output num if the handle refers to a remote tensor.
+  // If wait_until_ready is true, block until the remote tensor is ready on the
+  // given remote worker.
+  Status RemoteAddress(const Device* d, const bool wait_until_ready,
+                       int64* op_id, int32* output_num) const;
 
   // Called on an async remote tensor once it's shape has been determined. This
   // transitions the tensor handle from a non-ready to a ready state by
@@ -970,8 +970,9 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
   EXPECT_EQ(handle2->op_device()->name(), device2);
   int64 op_id;
   int32 output_num;
-  TF_ASSERT_OK(handle2->RemoteAddressUntilReady(
-      absl::get<Device*>(handle2->device()), &op_id, &output_num));
+  TF_ASSERT_OK(handle2->RemoteAddress(absl::get<Device*>(handle2->device()),
                                      /*wait_until_ready=*/true, &op_id,
                                      &output_num));
   EXPECT_EQ(op_id, 2);
   EXPECT_EQ(output_num, 5);
 
@@ -147,7 +147,8 @@ void RemoteCopyNode::StartSend() {
   request.set_context_id(ctx_->GetContextId());
   auto* remote_op = request.add_queue()->mutable_operation();
   status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
-      src_, remote_op->add_op_inputs()->mutable_remote_handle(),
+      src_, /*wait_until_ready=*/false,
+      remote_op->add_op_inputs()->mutable_remote_handle(),
       absl::get<Device*>(src_->device()),
       absl::get<Device*>(src_->DeviceOrHostCPU(*ctx_))->name());
   if (!status.ok()) {
@@ -316,7 +317,8 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
           (i == 0) && (h->dtype == DT_RESOURCE) &&
           (ctx->OnSameTask(src_device, target_device));
       TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
-          h, op->add_handles()->mutable_remote_handle(), src_device,
+          h, /*wait_until_ready=*/false,
+          op->add_handles()->mutable_remote_handle(), src_device,
           absl::get<Device*>(h->DeviceOrHostCPU(*ctx))->name(),
           serialize_resource_dtype_and_shape));
     } else {
@@ -74,6 +74,7 @@ Status RemoteMgr::GetMirroredResourceShape(
 }
 
 Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
+                                        const bool wait_until_ready,
                                         int64* op_id, int32* output_num) {
   // TODO(allenl): Consider supporting remote handles on custom devices.
   VariantDevice device = handle->device();
@@ -82,8 +83,8 @@ Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
         "Custom devices and remote execution are currently not supported "
         "together.");
   }
-  TF_RETURN_IF_ERROR(handle->RemoteAddressUntilReady(absl::get<Device*>(device),
-                                                     op_id, output_num));
+  TF_RETURN_IF_ERROR(handle->RemoteAddress(
+      absl::get<Device*>(device), wait_until_ready, op_id, output_num));
   tensorflow::TensorHandle* h;
   TF_RETURN_IF_ERROR(
       GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
@@ -120,13 +121,15 @@ Status RemoteMgr::DeleteTensorHandle(
 }
 
 Status RemoteMgr::SerializeRemoteTensorHandle(
-    TensorHandle* in, RemoteTensorHandle* out, Device* device,
-    const string& device_name, const bool serialize_resource_dtype_and_shape) {
+    TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out,
+    Device* device, const string& device_name,
+    const bool serialize_resource_dtype_and_shape) {
   int64 op_id;
   int32 output_num;
-  if (!in->RemoteAddressUntilReady(device, &op_id, &output_num).ok()) {
+  if (!in->RemoteAddress(device, wait_until_ready, &op_id, &output_num).ok()) {
     tf_shared_lock l(remote_tensor_handle_mu_);
-    TF_RETURN_IF_ERROR(GetRemoteTensorHandle(in, &op_id, &output_num));
+    TF_RETURN_IF_ERROR(
+        GetRemoteTensorHandle(in, wait_until_ready, &op_id, &output_num));
   }
   out->Clear();
   out->set_op_id(op_id);
@@ -61,9 +61,11 @@ class RemoteMgr {
   }
 
   // Serialize a remote TensorHandle to a RemoteTensorHandle.
+  // If wait_until_ready is true, block until the remote handle is ready on a
+  // remote worker.
   Status SerializeRemoteTensorHandle(
-      TensorHandle* in, RemoteTensorHandle* out, Device* device,
-      const string& device_name,
+      TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out,
+      Device* device, const string& device_name,
       const bool serialize_resource_dtype_and_shape = false);
 
   // Deserialize a RemoteTensorHandle to a TensorHandle(local/remote).
@@ -83,7 +85,8 @@ class RemoteMgr {
   // Returns the op_id and output_num if the given local TensorHandle exists in
   // remote_tensor_handle_map_.
   Status GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
-                               int64* op_id, int32* output_num)
+                               const bool wait_until_ready, int64* op_id,
+                               int32* output_num)
       TF_SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_);
 
   Status GetTensorHandleImpl(const RemoteTensorHandleInternal& remote_handle,
@@ -81,7 +81,8 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
       handle->SetRemoteShape(shape, remote_device_, ctx_->GetContextViewId()));
   RemoteTensorHandle remote_handle;
   TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
-      handle, &remote_handle, remote_device_, remote_device_->name()));
+      handle, /*wait_until_ready=*/true, &remote_handle, remote_device_,
+      remote_device_->name()));
   EXPECT_EQ(op_id, remote_handle.op_id());
   EXPECT_EQ(output_num, remote_handle.output_num());
   EXPECT_EQ(remote_device_->name(), remote_handle.device());
@@ -97,7 +98,8 @@ TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
       op_id, output_num, DT_FLOAT, remote_device_, ctx_);
   RemoteTensorHandle remote_handle;
   TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
-      handle, &remote_handle, remote_device_, remote_device_->name()));
+      handle, /*wait_until_ready=*/true, &remote_handle, remote_device_,
+      remote_device_->name()));
   EXPECT_EQ(op_id, remote_handle.op_id());
   EXPECT_EQ(output_num, remote_handle.output_num());
   EXPECT_EQ(remote_device_->name(), remote_handle.device());
@@ -194,9 +194,12 @@ string RemoteTensorHandleData::DebugString() const {
                          " output_num: ", output_num_);
 }
 
-Status RemoteTensorHandleData::OpIdAndOutputNumUntilReady(
-    int64* op_id, int32* output_num) const {
-  TF_RETURN_IF_ERROR(WaitReady("OpIdAndOutputNumUntilReady"));
+Status RemoteTensorHandleData::OpIdAndOutputNum(const bool wait_util_ready,
                                                int64* op_id,
                                                int32* output_num) const {
+  if (wait_util_ready) {
+    TF_RETURN_IF_ERROR(WaitReady("OpIdAndOutputNumUntilReady"));
+  }
   *op_id = op_id_;
   *output_num = output_num_;
   return Status::OK();
@@ -50,9 +50,10 @@ class RemoteTensorHandleData {
 
   string DebugString() const;
 
-  // Block until the remote tensor is ready on a remote worker and return the op
-  // id and output num.
-  Status OpIdAndOutputNumUntilReady(int64* op_id, int32* output_num) const;
+  // Return the op id and output num. If wait_util_ready is true, block until
+  // the remote tensor is ready on a remote worker.
+  Status OpIdAndOutputNum(const bool wait_util_ready, int64* op_id,
+                          int32* output_num) const;
 
   uint64 context_view_id() const { return context_view_id_; }
 