diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 35d4177f3da..24582147479 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -551,8 +551,9 @@ Status GetOrCreateKernelAndDevice(
         ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
   }
 
-  TF_RETURN_IF_ERROR(
-      kernel->Init({ctx.LogDevicePlacement()}, ndef, graph_collector));
+  TF_RETURN_IF_ERROR(kernel->Init(
+      {ctx.LogDevicePlacement(), ctx.LazyCopyFunctionRemoteInputs()}, ndef,
+      graph_collector));
 
   if (op->is_function()) {
     ctx.AddKernelToCache(cache_key, kernel.get());
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 00d832365e9..5f0dce21e8e 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -223,7 +223,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
 Status KernelAndDeviceFunc::Init(const Context& ctx, const NodeDef& ndef,
                                  GraphCollector* graph_collector) {
   TF_RETURN_IF_ERROR(InstantiateFunc(ctx, ndef, graph_collector));
-  return pflr_->GetOutputDevices(handle_, &output_devices_);
+  return pflr_->GetOutputDevices(handle_, &output_devices_,
+                                 ctx.eager_lazy_copy);
 }
 
 namespace {
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index 7bf4afbaf24..0a765510d7b 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -97,6 +97,7 @@ class KernelAndDevice : public core::RefCounted {
  public:
   struct Context {
     bool log_device_placement = false;
+    bool eager_lazy_copy = false;
   };
 
   // Populates this with a kernel appropriate for 'ndef'.
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 73450aa635f..ac3343e5a61 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -466,18 +466,6 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
                 << " src_device: " << *src_device
                 << " colo group: " << colocation_group;
       }
-      // If colocation_group is not set and output producing node is assigned
-      // to a remote device, colocate the retval node with its input node.
-      // TODO(yujingzhang): Remove this when we support outputting tensors on
-      // remote devices.
-      const bool remote_src_device =
-          !src_device->empty() && GetFLR(*src_device) == nullptr;
-      if (colocation_group.empty() && remote_src_device) {
-        colocation_group =
-            absl::StrCat(kColocationGroupPrefix, it->src()->name());
-        VLOG(3) << "Considering src: " << src_node->name()
-                << " colo group: " << colocation_group;
-      }
 
       // If resource is produced by a function call node, we can't trust
       // source node device assignment, because multi-device functions can
@@ -510,6 +498,20 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
               "Unable to find any devices for spec ", *src_device);
         }
       } else if (matching_devices.size() != 1) {
+        bool on_same_task = true;
+        for (int i = 1; i < matching_devices.size(); ++i) {
+          if (!DeviceNameUtils::IsSameAddressSpace(
+                  matching_devices.at(0)->parsed_name(),
+                  matching_devices.at(i)->parsed_name())) {
+            on_same_task = false;
+            break;
+          }
+        }
+        // If the src node of an output is assigned to an address space (e.g.
+        // py_func), rely on the placer to assign a device to the output.
+        if (on_same_task) {
+          continue;
+        }
         // Convert a vector of devices to a string.
         // Using absl::StrJoin did not work in Android builds.
         string devices = "[";
@@ -523,6 +525,7 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
         devices.append("]");
 
         return errors::InvalidArgument(
+            *src_device,
            "When FunctionLibraryRuntime::Options.output_devices are "
            "not specified for a multi-device function, the device "
            "specification on the output node must match exactly one "
@@ -968,6 +971,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
       Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
       done(s);
     } else {
+      opts.ret_indices = comp_data->ret_indices;
       // Initialize remote function asynchronously.
       InstantiateRemote(unique_name, attrs, opts, component_handle, done);
     }
@@ -988,9 +992,9 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
 }
 
 Status ProcessFunctionLibraryRuntime::GetOutputDevices(
-    FunctionLibraryRuntime::Handle handle,
-    std::vector<Device*>* output_devices) const {
-  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
+    FunctionLibraryRuntime::Handle handle, std::vector<Device*>* output_devices,
+    const bool eager_lazy_copy) const {
+  MultiDeviceFunctionData* data = IsMultiDevice(handle);
   if (data == nullptr) {
     return errors::InvalidArgument(
         "Failed for find multi-device function handle ", handle);
@@ -1008,6 +1012,19 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
     Device* target_device = nullptr;
     Device* host = nullptr;
     if (target_flr == nullptr) {
+      if (!eager_lazy_copy) {
+        return errors::Unimplemented(
+            "Currently, outputting tensors on remote devices is not supported. "
+ "The ", + comp_data.ret_indices[0], + "-th return value of the function outputs to target_device: ", + target, + " Please copy the tensor to local device explicitly using " + "tf.identity and return the new Tensor instead."); + } + if (!data->has_remote_outputs) { + data->has_remote_outputs = true; + } target_device = device_set()->FindDeviceByName(target); string remote_host; TF_RETURN_IF_ERROR( @@ -1607,7 +1624,12 @@ void ProcessFunctionLibraryRuntime::Run( FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) const { - if (!args.HasRemoteOrPackedInputs()) { + bool has_remote_outputs = false; + const MultiDeviceFunctionData* data = IsMultiDevice(handle); + if (data != nullptr) { + has_remote_outputs = data->has_remote_outputs; + } + if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) { const std::vector local_inputs = args.GetLocalTensors(); std::vector* tensor_rets = new std::vector; return Run( diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 69cd974b124..a882f5406d3 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -151,7 +151,8 @@ class ProcessFunctionLibraryRuntime { // is set to the device backing the resource. // REQUIRES: `handle` identifies a multi-device function. Status GetOutputDevices(FunctionLibraryRuntime::Handle handle, - std::vector* output_devices) const; + std::vector* output_devices, + const bool eager_lazy_copy) const; // Returns true if function with handle `handle` was instantiated on device // `device_name`. Returns false for multi-device functions. @@ -271,7 +272,8 @@ class ProcessFunctionLibraryRuntime { lib_def_(std::move(lib_def)), num_outputs_(num_outputs), ret_types_(std::move(ret_types)), - is_cross_process_(false) {} + is_cross_process_(false), + has_remote_outputs(false) {} const string function_name_; const string function_key_; @@ -285,6 +287,8 @@ class ProcessFunctionLibraryRuntime { // Indicates whether this function needs to execute cross process. bool is_cross_process_; + // Indicates whether this function has remote outputs. + bool has_remote_outputs; // Maps the device name to the information about the component function // be run on this device. diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 30512295a7e..505e0c305d6 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -105,6 +105,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:worker_proto_cc", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h index eb9ce64bcdb..4655bce44f9 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h @@ -15,6 +15,7 @@ limitations under the License. 
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/distributed_runtime/worker_interface.h"
 #include "tensorflow/core/framework/function.h"
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index c27758cbb44..fb9808b80cf 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -44,6 +44,7 @@ cc_library(
         "//tensorflow/core/common_runtime/eager:tensor_handle",
         "//tensorflow/core/distributed_runtime:call_options",
         "//tensorflow/core/distributed_runtime:worker_session",
+        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
     ],
diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc
index 0e0cd808504..e9801d65b49 100644
--- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc
+++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc
@@ -96,14 +96,16 @@ void EagerClusterFunctionLibraryRuntime::Instantiate(
           .ToProto();
   StripDefaultAttributesInRegisterFunctionOp(register_function);
 
+  const absl::optional<std::vector<int>>& ret_indices = options.ret_indices;
   eager_client->EnqueueAsync(
       /*call_opts=*/nullptr, request.get(), response.get(),
       [this, request, response, handle, released_op = released_op.release(),
-       target, eager_client = eager_client.get(), done](const Status& s) {
+       target, ret_indices, eager_client = eager_client.get(),
+       done](const Status& s) {
        {
          mutex_lock l(mu_);
          *handle = function_data_.size();
-          function_data_.emplace_back(target, eager_client,
+          function_data_.emplace_back(target, ret_indices, eager_client,
                                       absl::WrapUnique(released_op));
        }
        done(s);
@@ -168,6 +170,12 @@ void EagerClusterFunctionLibraryRuntime::Run(
   request->set_context_id(context_id_);
 
   eager::Operation* remote_op = request->mutable_operation();
+  if (function_data->ret_indices.has_value()) {
+    for (const int ret_index : function_data->ret_indices.value()) {
+      request->add_output_num(ret_index);
+    }
+  }
+
   for (const auto& arg : args) {
     if (arg.index() == 0) {
       absl::get<Tensor>(arg).AsProtoTensorContent(
diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h
index 6e60ee0b13d..01e864053d1 100644
--- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h
+++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h
@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
@@ -84,12 +85,15 @@ class EagerClusterFunctionLibraryRuntime
 
   struct FunctionData {
     const string target;
+    const absl::optional<std::vector<int>> ret_indices;
     core::RefCountPtr<EagerClient> eager_client;
     std::unique_ptr<EagerOperation> op;
 
-    FunctionData(const string& target, EagerClient* eager_client,
-                 std::unique_ptr<EagerOperation> op)
+    FunctionData(const string& target,
+                 const absl::optional<std::vector<int>>& ret_indices,
+                 EagerClient* eager_client, std::unique_ptr<EagerOperation> op)
        : target(target),
+          ret_indices(ret_indices),
          eager_client(core::RefCountPtr<EagerClient>(eager_client)),
          op(std::move(op)) {
      eager_client->Ref();
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 0e4eb9cf1dc..c3ed312428b 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -171,7 +171,8 @@ Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
 
 Status AddOpRetvalsToResponse(
     EagerContext* eager_context, int op_id, int num_retvals,
-    TensorHandle** retvals, std::function<TensorProto*()> add_tensor_proto_fn,
+    const std::vector<int32>& output_nums, TensorHandle** retvals,
+    std::function<TensorProto*()> add_tensor_proto_fn,
     std::function<TensorShapeProto*()> add_shape_proto_fn,
     std::function<string*()> add_device_fn = nullptr) {
   if (op_id == kInvalidRemoteOpId) {
@@ -195,7 +196,9 @@ Status AddOpRetvalsToResponse(
     if (is_remote) {
       retvals[i]->Unref();
     } else {
-      eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id, i);
+      const int output_num = output_nums.empty() ? i : output_nums.at(i);
+      eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id,
+                                                     output_num);
     }
   }
 }
@@ -474,6 +477,10 @@ void EagerServiceImpl::RunComponentFunction(
   auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
   VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
          << operation.id();
+  std::vector<int32> output_nums;
+  for (const int32 output_num : request->output_num()) {
+    output_nums.push_back(output_num);
+  }
 
   auto cm = std::make_shared<CancellationManager>();
   op->SetCancellationManager(cm.get());
@@ -482,8 +489,8 @@ void EagerServiceImpl::RunComponentFunction(
   context->Ref();
   EagerLocalExecuteAsync(
       op, retvals->data(), num_retvals,
-      [op, op_id = operation.id(), num_retvals, retvals, cm, call_opts,
-       response, eager_context, context,
+      [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
+       call_opts, response, eager_context, context,
       done = std::move(done)](const Status& status) {
        call_opts->ClearCancelCallback();
        auto wrapped_done = [&](const Status& status) {
@@ -500,7 +507,7 @@ void EagerServiceImpl::RunComponentFunction(
          // The output device of a component function is the component device
          // which is known on the default device of it's parent function.
          wrapped_done(AddOpRetvalsToResponse(
-              eager_context, op_id, *num_retvals, retvals->data(),
+              eager_context, op_id, *num_retvals, output_nums, retvals->data(),
              [response] { return response->add_tensor(); },
              [response] { return response->add_shape(); }));
        });
@@ -539,8 +546,8 @@ Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
   }
 
   return AddOpRetvalsToResponse(
-      eager_context, operation.id(), num_retvals, retvals.data(),
-      [queue_response] { return queue_response->add_tensor(); },
+      eager_context, operation.id(), num_retvals, /*output_nums=*/{},
+      retvals.data(), [queue_response] { return queue_response->add_tensor(); },
      [queue_response] { return queue_response->add_shape(); },
      std::move(add_device_fn));
 }
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index 2e603a298ba..700cea117de 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -224,10 +224,11 @@ void AddOperationToRunComponentFunctionRequest(
    const std::vector<absl::variant<tensorflow::Tensor, std::pair<int64, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device,
-    RunComponentFunctionRequest* request) {
+    const int output_num, RunComponentFunctionRequest* request) {
   auto* operation = request->mutable_operation();
   operation->set_is_function(true);
   operation->set_is_component_function(true);
+  request->add_output_num(output_num);
   BuildOperation(operation, id, name, inputs, attrs, device);
 }
 
@@ -610,10 +611,12 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
     RunComponentFunctionRequest run_comp_func_request;
     run_comp_func_request.set_context_id(context_id);
     RunComponentFunctionResponse run_comp_func_response;
+    const int output_num = 5;
     AddOperationToRunComponentFunctionRequest(
        2, function_name, {std::make_pair(1, 0)},
        std::unordered_map<string, AttrValue>(),
-        "/job:localhost/replica:0/task:0/device:CPU:0", &run_comp_func_request);
+        "/job:localhost/replica:0/task:0/device:CPU:0", output_num,
+        &run_comp_func_request);
 
     CallOptions call_opts;
     Notification n;
@@ -636,7 +639,8 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
     const tensorflow::Tensor* t = nullptr;
     tensorflow::TensorHandle* tensor_handle;
     TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
-        context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
+        context_id, RemoteTensorHandleInternal(2, output_num),
+        &tensor_handle));
     TF_ASSERT_OK(tensor_handle->Tensor(&t));
 
     auto actual = t->flat<float>();
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index c7e6e2d158c..3c7c09eee37 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -612,6 +612,9 @@ class FunctionLibraryRuntime {
    // infer correct device.
    std::vector<string> output_devices;
 
+    // If set, it indicates the original output indices of a component function.
+    absl::optional<std::vector<int>> ret_indices = absl::nullopt;
+
    // Maps from a CompositeDevice name to a list of underlying physical
    // devices.
    absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index 03f8357276f..204acf6b1df 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -180,6 +180,9 @@ message RunComponentFunctionRequest {
   fixed64 context_id = 1;
 
   Operation operation = 2;
+
+  // The output indices of its parent function.
+  repeated int32 output_num = 3;
 }
 
 message RunComponentFunctionResponse {
diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py
index c661ed98bf5..429068149b1 100644
--- a/tensorflow/python/eager/remote_test.py
+++ b/tensorflow/python/eager/remote_test.py
@@ -92,7 +92,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
 
     self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])
 
-  @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionRemoteOutput(self):
     with ops.device('/job:worker/replica:0/task:0/cpu:0'):
       variable_b = variables.Variable(1)
@@ -101,10 +100,15 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
     def remote_output(i):
       with ops.device('/job:worker/replica:0/task:0/cpu:0'):
         c = variable_b + 1
-      return c, i + variable_b
+      return i + variable_b, c
 
-    self.assertAllEqual(
-        remote_output(constant_op.constant([1]))[0].numpy(), 2)
+    rets = remote_output(constant_op.constant([1]))
+    self.assertEqual(rets[0].backing_device,
+                     '/job:localhost/replica:0/task:0/device:CPU:0')
+    self.assertEqual(rets[1].backing_device,
+                     '/job:worker/replica:0/task:0/device:CPU:0')
+    self.assertAllEqual(rets[0].numpy(), [2])
+    self.assertAllEqual(rets[1].numpy(), 2)
 
   def testMultiDeviceFunctionAmbiguousDevice(self):
@@ -482,6 +486,25 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
       with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
         self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
+  def testMultiDeviceFunctionRemoteOutput(self):
+    with ops.device('/job:worker/replica:0/task:1/cpu:0'):
+      variable_b = variables.Variable(1)
+
+    @def_function.function
+    def remote_output(i):
+      with ops.device('/job:worker/replica:0/task:1/cpu:0'):
+        c = variable_b + 1
+      return i + variable_b, c
+
+    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
+      rets = remote_output(constant_op.constant([1]))
+    self.assertEqual(rets[0].backing_device,
+                     '/job:worker/replica:0/task:0/device:CPU:0')
+    self.assertEqual(rets[1].backing_device,
+                     '/job:worker/replica:0/task:1/device:CPU:0')
+    self.assertAllEqual(rets[0].numpy(), [2])
+    self.assertAllEqual(rets[1].numpy(), 2)
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceWhileLoopOnRemoteDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):