Add support for mirroring remote tensors locally
As part of this change we disable async support for local mirrors until TensorHandle can support waiting on local mirrors.

PiperOrigin-RevId: 296960921
Change-Id: I92be930841916b39d98a8edfe348a2a0813ef92b
parent aaf6f810db
commit 411185a8fe
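For context, here is how the new implicit-mirroring path is exercised through the eager C API, as in the updated remote test below. This is a minimal sketch, not a complete test: ctx, h1_task0, and task2_name stand in for the context, input handle, and target device name that the test sets up.

    TF_Status* status = TF_NewStatus();
    // Copy the tensor to the remote task, then opt the handle into mirroring.
    TFE_TensorHandle* h1_task2 =
        TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
    TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
    // Subsequent copies to that device reuse the mirror instead of sending
    // the tensor again; with this change the same applies to local mirrors,
    // but only for synchronous executors until async waiting is supported.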
@@ -205,6 +205,7 @@ tf_cuda_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+        "//tensorflow/core/platform:casts",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_test_util.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+#include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/cluster.pb.h"
@@ -127,7 +128,7 @@ void TestRemoteExecute(bool async) {
 TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
 TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
 
-void TestRemoteExecuteSilentCopies(bool async) {
+void TestRemoteExecuteSilentCopies(bool async, bool remote) {
   tensorflow::ServerDef server_def = GetServerDef(3);
 
   // This server def has the task index set to 0.
@@ -166,10 +167,14 @@ void TestRemoteExecuteSilentCopies(bool async) {
   auto* h1_task2 =
       TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   // Handles are on task0 (local), and task2, but op is on task1.
   TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
-  TFE_OpSetDevice(matmul, task1_name, status);
+  if (remote) {
+    TFE_OpSetDevice(matmul, task1_name, status);
+  }
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   TFE_TensorHandle* retvals[1];
@@ -177,6 +182,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
+  // TODO(gjn): Add support for waiting on async local mirrors
+  if (!async) {
+    auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+                          h1_task2->handle.get())
+                          ->Handle();
+    auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
+        matmul->operation.get());
+    // The input handles should never change since they have been mirrored.
+    ASSERT_EQ(op->GetInput(1), remote_arg);
+  }
+
   auto* retval_task0 = TFE_TensorHandleCopyToDevice(
       retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -213,9 +229,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
   worker_server2.release();
 }
 
-TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
+TEST(CAPI, RemoteExecuteSilentCopies) {
+  TestRemoteExecuteSilentCopies(false, true);
+}
 TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
-  TestRemoteExecuteSilentCopies(true);
+  TestRemoteExecuteSilentCopies(true, true);
 }
+TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
+  TestRemoteExecuteSilentCopies(false, false);
+}
+TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
+  TestRemoteExecuteSilentCopies(true, false);
+}
 
 void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
@@ -1057,7 +1057,10 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
     return Status::OK();
   }
 
-  if (mirror) {
+  // TODO(gjn): Need to add support for async execution. Note if receiver
+  // is local, we need to first add support in TensorHandle to wait on local
+  // mirrors.
+  if (mirror && !executor->Async()) {
     TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d));
     h->Ref();
     *result = h;
@@ -1091,7 +1094,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
   }
   bool sender_is_local = absl::get<Device*>(send_device)->IsLocal();
 
-  bool recver_is_local = device->IsLocal();
+  bool receiver_is_local = device->IsLocal();
 
   if (!executor->Async()) {
     // In sync mode, always clear error to maintain the same behavior as before.
@@ -1099,26 +1102,42 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
     executor->ClearError();
   }
 
-  if (sender_is_local && recver_is_local) {
+  if (sender_is_local && receiver_is_local) {
     return LocalEagerCopyToDevice(h, ctx, executor, device, mirror, result);
   } else {
 #if defined(IS_MOBILE_PLATFORM)
     return errors::Unimplemented(
         "Eager's remote execution is not available on mobile devices.");
 #else  // !IS_MOBILE_PLATFORM
-    if (mirror) {
-      if (h->HasRemoteMirror(device, ctx->GetContextViewId())) {
+    uint64 recv_op_id = 0;
+    if (receiver_is_local) {
+      Device* d = ctx->CanonicalDevice(device);
+      if (mirror && h->HasLocalMirror(d)) {
         h->Ref();
         *result = h;
         return Status::OK();
       }
-    }
-    uint64 recv_op_id = 0;
-    if (recver_is_local) {
-      TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
-          true, /* d= */ ctx->CanonicalDevice(device), /* op_device= */ device,
-          /*resource_device=*/nullptr, h->dtype, ctx, result));
+
+      // TODO(gjn): Need to add support for async execution. Note if receiver
+      // is local, we need to first add support in TensorHandle to wait on local
+      // mirrors.
+      if (mirror && !executor->Async()) {
+        TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d));
+        h->Ref();
+        *result = h;
+      } else {
+        TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
+            true, /* d= */ d, /* op_device= */ device,
+            /*resource_device=*/nullptr, h->dtype, ctx, result));
+      }
     } else {
+      if (mirror) {
+        if (h->HasRemoteMirror(device, ctx->GetContextViewId())) {
+          h->Ref();
+          *result = h;
+          return Status::OK();
+        }
+      }
       string remote_task;
       if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
         return errors::InvalidArgument(
@@ -1139,6 +1158,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
           std::move(tensor_handle_data), h->dtype, device, ctx, result));
       }
     }
+
     auto node = absl::make_unique<eager::RemoteCopyNode>(
         ctx, executor, h, result[0], device, recv_op_id);
     Status s = executor->AddOrExecute(std::move(node));
@@ -27,28 +27,29 @@ Status ExecuteNodeArgs::Init(
   // be decremented once execution is complete.
   const int n_inputs = op_inputs.size();
   if (n_inputs > 0) {
-    TensorHandle* const* op_inputs_array = &op_inputs[0];
-    TensorValue* tensor_args_array = &tensor_args_[0];
+    TensorHandle* const* op_inputs_flat = &op_inputs[0];
+    TensorValue* tensor_args_flat = &tensor_args_[0];
     for (int i = 0; i < n_inputs; ++i) {
-      TensorHandle* in = op_inputs_array[i];
-      if (!in->IsRemote()) {
-        TF_RETURN_IF_ERROR(
-            in->TensorValue(&tensor_args_array[i],
-                            ctx->CanonicalDevice(kernel->InputDevice(i))));
-      } else {
-        if (!has_remote_inputs_) {
-          has_remote_inputs_ = true;
+      TensorHandle* in = op_inputs_flat[i];
+      Device* d = kernel->InputDevice(i);
+      Status s = in->TensorValue(&tensor_args_flat[i], ctx->CanonicalDevice(d));
+      if (!s.ok()) {
+#if !defined(IS_MOBILE_PLATFORM)
+        uint64 context_view_id = ctx->GetContextViewId();
+        if (in->IsRemote() || in->HasRemoteMirror(d, context_view_id)) {
+          if (!has_remote_inputs_) {
+            has_remote_inputs_ = true;
+          }
+          continue;
         }
+#endif
+        return s;
       }
     }
   }
 
-#if !defined(IS_MOBILE_PLATFORM)
   if (has_remote_inputs_) {
+#if defined(IS_MOBILE_PLATFORM)
+    return errors::Unimplemented(
+        "Eager's function execution with remote inputs is not available on "
+        "mobile devices.");
+#else  // !IS_MOBILE_PLATFORM
     serialize_remote_handle_ =
         [ctx, &op_inputs](const int i,
                           eager::RemoteTensorHandle* handle) -> Status {
@@ -63,8 +64,8 @@ Status ExecuteNodeArgs::Init(
       return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
           op_inputs[i], handle, device, device->name());
     };
+#endif  // !IS_MOBILE_PLATFORM
   }
-#endif  // !IS_MOBILE_PLATFORM
   return Status::OK();
 }
@@ -334,9 +334,12 @@ Status TensorHandle::Tensor(const tensorflow::Tensor** t) const {
 
 Status TensorHandle::TensorFromDevice(const Device* d,
                                       const tensorflow::Tensor** t) const {
-  TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorFromDevice"));
-
   if (d == absl::get<Device*>(device_)) {
     if (is_remote_) {
       return errors::Internal("Invalid Tensor call on remote handle: ", this);
     }
 
+    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorFromDevice"));
     return tensor_handle_data_->Tensor(t);
   }
@@ -348,6 +351,7 @@ Status TensorHandle::TensorFromDevice(const Device* d,
 
   auto empty_mirror = empty_local_mirrors_.find(d);
   if (empty_mirror != empty_local_mirrors_.end()) {
+    // TODO(gjn): Add support for waiting on local mirrors
     return errors::Internal("Attempted to get Tensor for empty mirror");
   }
@@ -356,9 +360,13 @@
 }
 
 Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) {
-  TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue"));
-
   if (d == absl::get<Device*>(device_)) {
     if (is_remote_) {
       return errors::Internal("Invalid TensorValue call on remote handle: ",
                               this);
     }
 
+    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue"));
     return tensor_handle_data_->TensorValue(t);
   }
@@ -370,6 +378,7 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) {
 
   auto empty_mirror = empty_local_mirrors_.find(d);
   if (empty_mirror != empty_local_mirrors_.end()) {
+    // TODO(gjn): Add support for waiting on local mirrors
    return errors::Internal("Attempted to get TensorValue for empty mirror");
   }
@@ -532,6 +541,9 @@ bool TensorHandle::HasLocalMirror(const Device* d) const {
 }
 
 Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
+  DVLOG(3) << "AddEmptyLocalMirror on TensorHandle: " << this
+           << " device: " << d;
+
   mutex_lock l(mu_);
   if (local_mirrors_.find(d) != local_mirrors_.end()) {
     return errors::Internal("Attempted to duplicate a local mirror.");