diff --git a/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc b/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc index e4a227e463f..c073dc1fd88 100644 --- a/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/eager/process_function_library_runtime.cc @@ -31,12 +31,6 @@ void EagerProcessFunctionLibraryRuntime::RunRemoteDevice( FunctionLibraryRuntime::Handle local_handle, gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) const { - if (!rets->empty()) { - done( - errors::Unimplemented("Remote outputs are not supported by " - "EagerClusterFunctionLibraryRuntime yet.")); - return; - } parent_->Run(opts, local_handle, args, rets, std::move(done)); } diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc index dfa35086659..f2f63a8fab5 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc @@ -120,15 +120,11 @@ void EagerClusterFunctionLibraryRuntime::Run( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) { - FunctionLibraryRuntime::Options opts_copy = opts; - if (!opts_copy.op_id.has_value()) { - opts_copy.op_id = ctx_->RemoteMgr()->NextOpId(); - } - std::vector function_args; - for (const auto& tensor : args) { - function_args.push_back(tensor); - } - Run(opts_copy, handle, function_args, rets, std::move(done)); + std::vector function_args; + for (const auto& tensor : args) { + function_args.push_back(tensor); + } + Run(opts, handle, function_args, rets, std::move(done)); } void EagerClusterFunctionLibraryRuntime::Run( @@ -165,11 +161,6 @@ void EagerClusterFunctionLibraryRuntime::Run( EagerOperation* op = function_data->op.get(); - if (!opts.op_id.has_value()) { - done( - errors::Internal("op_id is not set for remote function: ", op->Name())); - } - eager::EnqueueRequest* request = new eager::EnqueueRequest; request->set_context_id(context_id_); eager::Operation* remote_op = request->add_queue()->mutable_operation(); @@ -187,7 +178,11 @@ void EagerClusterFunctionLibraryRuntime::Run( // The remote component function should use the same op_id as its parent // multi-device function's in order to get the global unique op_id generated // by the master context. - remote_op->set_id(opts.op_id.value()); + if (opts.op_id.has_value()) { + remote_op->set_id(opts.op_id.value()); + } else { + remote_op->set_id(kInvalidRemoteOpId); + } remote_op->set_is_function(true); remote_op->set_is_component_function(true); remote_op->set_func_step_id(opts.step_id); @@ -203,15 +198,39 @@ void EagerClusterFunctionLibraryRuntime::Run( // disabled, Run() returns when the remote function execution completes, which // might be blocked by a non-enqueued function execution. EnqueueResponse* response = new EnqueueResponse; - eager_client->EnqueueAsync(request, response, - [op, request, response, done](const Status& s) { - for (auto handle : op->Inputs()) { - handle->Unref(); - } - done(s); - delete request; - delete response; - }); + eager_client->EnqueueAsync( + request, response, + [op, request, response, rets, done = std::move(done)](const Status& s) { + Status status = s; + auto cleanup = gtl::MakeCleanup([request, response, &status, &done] { + done(status); + delete request; + delete response; + }); + + for (auto handle : op->Inputs()) { + handle->Unref(); + } + if (!status.ok()) { + return; + } + if (response->queue_response_size() != 1) { + status.Update(errors::Internal( + "Expect that the size of response queue equals 1, but got: ", + response->queue_response_size())); + return; + } + for (const auto& tensor_proto : response->queue_response(0).tensor()) { + Tensor t; + if (t.FromProto(tensor_proto)) { + rets->push_back(std::move(t)); + } else { + status.Update(errors::Internal("Could not convert tensor proto: ", + tensor_proto.DebugString())); + return; + } + } + }); } void EagerClusterFunctionLibraryRuntime::CleanUp( diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index e9d73a35ff9..cf28e2680d8 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -327,6 +327,13 @@ Status EagerServiceImpl::CreateMasterContext( return Status::OK(); } +Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) { + const tensorflow::Tensor* t = nullptr; + TF_RETURN_IF_ERROR(handle->Tensor(&t)); + t->AsProtoTensorContent(proto); + return Status::OK(); +} + Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { const tensorflow::Tensor* t = nullptr; @@ -412,12 +419,21 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation, VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id(); TF_RETURN_IF_ERROR(EagerExecute(op.get(), retvals.data(), &num_retvals)); - eager_context->RemoteMgr()->AddOperationOutputs( - absl::MakeSpan(retvals.data(), num_retvals), operation.id()); - - for (int i = 0; i < num_retvals; i++) { - TF_RETURN_IF_ERROR( - TensorHandleShape(retvals[i], queue_response->add_shape())); + if (operation.id() == kInvalidRemoteOpId) { + // Copy the output tensors back along with the response, since the op id + // is invalid which cannot be added to RemoteMgr. + for (int i = 0; i < num_retvals; i++) { + TF_RETURN_IF_ERROR( + TensorHandleProto(retvals[i], queue_response->add_tensor())); + retvals[i]->Unref(); + } + } else { + eager_context->RemoteMgr()->AddOperationOutputs( + absl::MakeSpan(retvals.data(), num_retvals), operation.id()); + for (int i = 0; i < num_retvals; i++) { + TF_RETURN_IF_ERROR( + TensorHandleShape(retvals[i], queue_response->add_shape())); + } } return Status::OK(); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 73bc42be0c5..2006d0a4d5c 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -560,13 +560,8 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest { /*thread_pool=*/nullptr, eager_cluster_flr_.get()); } - void CheckOutputsAndClose(const int64 op_id) { - const tensorflow::Tensor* t = nullptr; - tensorflow::TensorHandle* tensor_handle; - TF_ASSERT_OK(eager_service_impl_.GetTensorHandle( - context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle)); - TF_ASSERT_OK(tensor_handle->Tensor(&t)); - auto actual = t->flat(); + void CheckOutputTensorAndClose(const Tensor& tensor) { + auto actual = tensor.flat(); EXPECT_EQ(4, actual.size()); EXPECT_EQ(7, actual(0)); EXPECT_EQ(10, actual(1)); @@ -581,6 +576,15 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest { &close_context_response)); } + void CheckOutputsAndClose(const int64 op_id) { + const tensorflow::Tensor* t = nullptr; + tensorflow::TensorHandle* tensor_handle; + TF_ASSERT_OK(eager_service_impl_.GetTensorHandle( + context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle)); + TF_ASSERT_OK(tensor_handle->Tensor(&t)); + CheckOutputTensorAndClose(*t); + } + protected: const string local_device_ = "/job:localhost/replica:0/task:0/device:CPU:0"; const string remote_device_ = "/job:localhost/replica:0/task:1/device:CPU:0"; @@ -649,8 +653,9 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) { CheckOutputsAndClose(op_id); } -// Test executes a remote function with a local tensor input. -TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) { +// Test executes a remote function with local input and output tensors. +TEST_F(FunctionWithRemoteInputsTest, + EagerClusterFLRTestWithLocalInputAndOutput) { Init(); // Instantiate MatMulFunction on remote_device. FunctionLibraryRuntime::Handle handle; @@ -681,11 +686,9 @@ TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) { context_id_, RemoteTensorHandleInternal(1, 0), &tensor_handle)); TF_ASSERT_OK(tensor_handle->Tensor(&input_tensor)); - // Send input_tensor to the remote device and execute MatMulFunction on the - // remote device. + // Send input_tensor to the remote device, execute MatMulFunction on the + // remote device, and send the output back. FunctionLibraryRuntime::Options opts; - const uint64 op_id = 2; - opts.op_id = op_id; Notification execute_done; std::vector inputs = {*input_tensor}; std::vector outputs; @@ -696,7 +699,8 @@ TEST_F(FunctionWithRemoteInputsTest, EagerClusterFLRTestWithLocalTensorInput) { }); execute_done.WaitForNotification(); TF_ASSERT_OK(status); - CheckOutputsAndClose(op_id); + EXPECT_EQ(outputs.size(), 1); + CheckOutputTensorAndClose(outputs.at(0)); } // Test executes a remote function through KernelAndDeviceFunc. diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.h b/tensorflow/core/distributed_runtime/eager/remote_mgr.h index d075345a027..54c987d4daa 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.h +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.h @@ -26,6 +26,8 @@ limitations under the License. namespace tensorflow { namespace eager { +const int64 kInvalidRemoteOpId = -1; + // This class manages the states required to setup an eager cluster. // TODO(fishx): Move remote state from context to this class. class RemoteMgr { diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto index 6f2913eae90..d57ca22b0d2 100644 --- a/tensorflow/core/protobuf/eager_service.proto +++ b/tensorflow/core/protobuf/eager_service.proto @@ -73,7 +73,12 @@ message QueueItem { } message QueueResponse { + // `shape` and `tensor` cannot be set in the same response. + // Shapes of output tensors for creating remote TensorHandles. repeated TensorShapeProto shape = 1; + + // Output tensors of a remote function. Set when Operation.id is invalid. + repeated TensorProto tensor = 2; } message CreateContextRequest { diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 2911e1ae83e..9c229974b05 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -862,7 +862,10 @@ cuda_py_test( ":def_function", ":remote", ":test", + "//tensorflow/python:dtypes", + "//tensorflow/python:functional_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:tensor_spec", "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index c159bd094f8..b32a773e894 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -31,11 +31,14 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import remote from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.training import server_lib @@ -167,6 +170,26 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase): with ops.device('/job:worker/task:0'): self.assertAllEqual(func(), 1) + @test_util.eager_lazy_remote_copy_on_and_off + def testRemoteCall(self): + + @def_function.function( + input_signature=[tensor_spec.TensorSpec([], dtypes.int32)]) + def _remote_fn(x): + return constant_op.constant(1) + x + + remote_fn = _remote_fn.get_concrete_function() + + @def_function.function + def func(x): + return functional_ops.remote_call( + args=[x], + Tout=[dtypes.int32], + f=remote_fn, + target='/job:worker/task:0') + + self.assertAllEqual(func(constant_op.constant(1)), [2]) + class RemoteAsyncTest(test.TestCase):