diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 3a6c2eef1fe..30ae001caf7 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -205,6 +205,7 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/platform:casts", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index 2f363a4f9a4..eb6b234e3df 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/cluster.pb.h" @@ -127,7 +128,7 @@ void TestRemoteExecute(bool async) { TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } -void TestRemoteExecuteSilentCopies(bool async) { +void TestRemoteExecuteSilentCopies(bool async, bool remote) { tensorflow::ServerDef server_def = GetServerDef(3); // This server def has the task index set to 0. @@ -166,10 +167,14 @@ void TestRemoteExecuteSilentCopies(bool async) { auto* h1_task2 = TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandleEnableImplicitMirroring(h1_task2, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // Handles are on task0 (local), and task2, but op is on task1. 
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2); - TFE_OpSetDevice(matmul, task1_name, status); + if (remote) { + TFE_OpSetDevice(matmul, task1_name, status); + } EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* retvals[1]; @@ -177,6 +182,17 @@ void TestRemoteExecuteSilentCopies(bool async) { TFE_Execute(matmul, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + // TODO(gjn): Add support for waiting on async local mirrors + if (!async) { + auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>( + h1_task2->handle.get()) + ->Handle(); + auto op = tensorflow::down_cast<tensorflow::OperationInterface*>( + matmul->operation.get()); + // The input handles should never change since they have been mirrored. + ASSERT_EQ(op->GetInput(1), remote_arg); + } + auto* retval_task0 = TFE_TensorHandleCopyToDevice( retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -213,9 +229,17 @@ void TestRemoteExecuteSilentCopies(bool async) { worker_server2.release(); } -TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); } +TEST(CAPI, RemoteExecuteSilentCopies) { + TestRemoteExecuteSilentCopies(false, true); +} TEST(CAPI, RemoteExecuteSilentCopiesAsync) { - TestRemoteExecuteSilentCopies(true); + TestRemoteExecuteSilentCopies(true, true); +} +TEST(CAPI, RemoteExecuteSilentCopiesLocal) { + TestRemoteExecuteSilentCopies(false, false); +} +TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) { + TestRemoteExecuteSilentCopies(true, false); } void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) { diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index e0b24b826db..3496a39714d 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -1057,7 +1057,10 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, return 
Status::OK(); } - if (mirror) { + // TODO(gjn): Need to add support for async execution. Note if receiver + // is local, we need to first add support in TensorHandle to wait on local + // mirrors. + if (mirror && !executor->Async()) { TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d)); h->Ref(); *result = h; @@ -1091,7 +1094,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, } bool sender_is_local = absl::get<Device*>(send_device)->IsLocal(); - bool recver_is_local = device->IsLocal(); + bool receiver_is_local = device->IsLocal(); if (!executor->Async()) { // In sync mode, always clear error to maintain the same behavior as before. @@ -1099,26 +1102,42 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, executor->ClearError(); } - if (sender_is_local && recver_is_local) { + if (sender_is_local && receiver_is_local) { return LocalEagerCopyToDevice(h, ctx, executor, device, mirror, result); } else { #if defined(IS_MOBILE_PLATFORM) return errors::Unimplemented( "Eager's remote execution is not available on mobile devices."); #else  // !IS_MOBILE_PLATFORM - if (mirror) { - if (h->HasRemoteMirror(device, ctx->GetContextViewId())) { + uint64 recv_op_id = 0; + if (receiver_is_local) { + Device* d = ctx->CanonicalDevice(device); + if (mirror && h->HasLocalMirror(d)) { h->Ref(); *result = h; return Status::OK(); } - } - uint64 recv_op_id = 0; - if (recver_is_local) { - TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( - true, /* d= */ ctx->CanonicalDevice(device), /* op_device= */ device, - /*resource_device=*/nullptr, h->dtype, ctx, result)); + + // TODO(gjn): Need to add support for async execution. Note if receiver + // is local, we need to first add support in TensorHandle to wait on local + // mirrors. 
+ if (mirror && !executor->Async()) { + TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d)); + h->Ref(); + *result = h; + } else { + TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( + true, /* d= */ d, /* op_device= */ device, + /*resource_device=*/nullptr, h->dtype, ctx, result)); + } } else { + if (mirror) { + if (h->HasRemoteMirror(device, ctx->GetContextViewId())) { + h->Ref(); + *result = h; + return Status::OK(); + } + } string remote_task; if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) { return errors::InvalidArgument( @@ -1139,6 +1158,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, std::move(tensor_handle_data), h->dtype, device, ctx, result)); } } + auto node = absl::make_unique<eager::RemoteCopyNode>( ctx, executor, h, result[0], device, recv_op_id); Status s = executor->AddOrExecute(std::move(node)); diff --git a/tensorflow/core/common_runtime/eager/execute_node.cc b/tensorflow/core/common_runtime/eager/execute_node.cc index c053420fe83..d523dc20084 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.cc +++ b/tensorflow/core/common_runtime/eager/execute_node.cc @@ -27,28 +27,29 @@ Status ExecuteNodeArgs::Init( // be decremented once execution is complete. 
const int n_inputs = op_inputs.size(); if (n_inputs > 0) { - TensorHandle* const* op_inputs_array = &op_inputs[0]; - TensorValue* tensor_args_array = &tensor_args_[0]; + TensorHandle* const* op_inputs_flat = &op_inputs[0]; + TensorValue* tensor_args_flat = &tensor_args_[0]; for (int i = 0; i < n_inputs; ++i) { - TensorHandle* in = op_inputs_array[i]; - if (!in->IsRemote()) { - TF_RETURN_IF_ERROR( - in->TensorValue(&tensor_args_array[i], - ctx->CanonicalDevice(kernel->InputDevice(i)))); - } else { - if (!has_remote_inputs_) { - has_remote_inputs_ = true; + TensorHandle* in = op_inputs_flat[i]; + Device* d = kernel->InputDevice(i); + Status s = in->TensorValue(&tensor_args_flat[i], ctx->CanonicalDevice(d)); + if (!s.ok()) { +#if !defined(IS_MOBILE_PLATFORM) + uint64 context_view_id = ctx->GetContextViewId(); + if (in->IsRemote() || in->HasRemoteMirror(d, context_view_id)) { + if (!has_remote_inputs_) { + has_remote_inputs_ = true; + } + continue; } +#endif + return s; } } } +#if !defined(IS_MOBILE_PLATFORM) if (has_remote_inputs_) { -#if defined(IS_MOBILE_PLATFORM) - return errors::Unimplemented( - "Eager's function execution with remote inputs is not available on " - "mobile devices."); -#else // !IS_MOBILE_PLATFORM serialize_remote_handle_ = [ctx, &op_inputs](const int i, eager::RemoteTensorHandle* handle) -> Status { @@ -63,8 +64,8 @@ Status ExecuteNodeArgs::Init( return ctx->RemoteMgr()->SerializeRemoteTensorHandle( op_inputs[i], handle, device, device->name()); }; -#endif // !IS_MOBILE_PLATFORM } +#endif // !IS_MOBILE_PLATFORM return Status::OK(); } diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index fb13e81d37d..4148400acae 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -334,9 +334,12 @@ Status TensorHandle::Tensor(const tensorflow::Tensor** t) const { Status TensorHandle::TensorFromDevice(const Device* 
d, const tensorflow::Tensor** t) const { - TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorFromDevice")); - if (d == absl::get<Device*>(device_)) { + if (is_remote_) { + return errors::Internal("Invalid Tensor call on remote handle: ", this); + } + + TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorFromDevice")); return tensor_handle_data_->Tensor(t); } @@ -348,6 +351,7 @@ Status TensorHandle::TensorFromDevice(const Device* d, auto empty_mirror = empty_local_mirrors_.find(d); if (empty_mirror != empty_local_mirrors_.end()) { + // TODO(gjn): Add support for waiting on local mirrors return errors::Internal("Attempted to get Tensor for empty mirror"); } @@ -356,9 +360,13 @@ Status TensorHandle::TensorFromDevice(const Device* d, } Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) { - TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue")); - if (d == absl::get<Device*>(device_)) { + if (is_remote_) { + return errors::Internal("Invalid TensorValue call on remote handle: ", + this); + } + + TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue")); return tensor_handle_data_->TensorValue(t); } @@ -370,6 +378,7 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) { auto empty_mirror = empty_local_mirrors_.find(d); if (empty_mirror != empty_local_mirrors_.end()) { + // TODO(gjn): Add support for waiting on local mirrors return errors::Internal("Attempted to get TensorValue for empty mirror"); } @@ -532,6 +541,9 @@ bool TensorHandle::HasLocalMirror(const Device* d) const { } Status TensorHandle::AddEmptyLocalMirror(const Device* d) { + DVLOG(3) << "AddEmptyLocalMirror on TensorHandle: " << this + << " device: " << d; + mutex_lock l(mu_); if (local_mirrors_.find(d) != local_mirrors_.end()) { return errors::Internal("Attempted to duplicate a local mirror.");