Use the original output indices when adding a component function output to RemoteMgr.
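Previously, when a component function ran on a remote worker, EagerServiceImpl registered each return value with RemoteMgr under its local retval index. This commit threads the parent function's original output indices through FunctionLibraryRuntime::InstantiateOptions::ret_indices and the new RunComponentFunctionRequest.output_num field, so the handles are registered under the indices the parent will later use to look them up. A minimal sketch of the server-side mapping (paraphrasing the AddOpRetvalsToResponse change in this diff, not the exact TensorFlow code):

  // For each retval, prefer the original output index supplied by the caller;
  // fall back to the local index i when none was provided (e.g. plain ExecuteOp).
  for (int i = 0; i < num_retvals; ++i) {
    const int output_num = output_nums.empty() ? i : output_nums.at(i);
    eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id, output_num);
  }

With lazy remote copy enabled, GetOutputDevices no longer returns Unimplemented for remote outputs; the function records has_remote_outputs and ProcessFunctionLibraryRuntime::Run takes the remote path, as exercised by the updated remote_test.py cases.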
PiperOrigin-RevId: 327507408
Change-Id: Ie33d8467aec3901340ac8edd8892f28811b92c2a
parent f06308bfc4
commit 699178a5d7
@@ -551,8 +551,9 @@ Status GetOrCreateKernelAndDevice(
         ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
   }
 
-  TF_RETURN_IF_ERROR(
-      kernel->Init({ctx.LogDevicePlacement()}, ndef, graph_collector));
+  TF_RETURN_IF_ERROR(kernel->Init(
+      {ctx.LogDevicePlacement(), ctx.LazyCopyFunctionRemoteInputs()}, ndef,
+      graph_collector));
 
   if (op->is_function()) {
     ctx.AddKernelToCache(cache_key, kernel.get());
@@ -223,7 +223,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
 Status KernelAndDeviceFunc::Init(const Context& ctx, const NodeDef& ndef,
                                  GraphCollector* graph_collector) {
   TF_RETURN_IF_ERROR(InstantiateFunc(ctx, ndef, graph_collector));
-  return pflr_->GetOutputDevices(handle_, &output_devices_);
+  return pflr_->GetOutputDevices(handle_, &output_devices_,
+                                 ctx.eager_lazy_copy);
 }
 
 namespace {
@@ -97,6 +97,7 @@ class KernelAndDevice : public core::RefCounted {
  public:
   struct Context {
     bool log_device_placement = false;
+    bool eager_lazy_copy = false;
   };
 
   // Populates this with a kernel appropriate for 'ndef'.
@@ -466,18 +466,6 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
                 << " src_device: " << *src_device
                 << " colo group: " << colocation_group;
       }
-      // If colocation_group is not set and output producing node is assigned
-      // to a remote device, colocate the retval node with its input node.
-      // TODO(yujingzhang): Remove this when we support outputting tensors on
-      // remote devices.
-      const bool remote_src_device =
-          !src_device->empty() && GetFLR(*src_device) == nullptr;
-      if (colocation_group.empty() && remote_src_device) {
-        colocation_group =
-            absl::StrCat(kColocationGroupPrefix, it->src()->name());
-        VLOG(3) << "Considering src: " << src_node->name()
-                << " colo group: " << colocation_group;
-      }
 
       // If resource is produced by a function call node, we can't trust
       // source node device assignment, because multi-device functions can
@@ -510,6 +498,20 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
               "Unable to find any devices for spec ", *src_device);
         }
       } else if (matching_devices.size() != 1) {
+        bool on_same_task = true;
+        for (int i = 1; i < matching_devices.size(); ++i) {
+          if (!DeviceNameUtils::IsSameAddressSpace(
+                  matching_devices.at(0)->parsed_name(),
+                  matching_devices.at(i)->parsed_name())) {
+            on_same_task = false;
+            break;
+          }
+        }
+        // If the src node of an output is assigned to a address space (e.g.
+        // py_func), rely on placer to assign a device to the output.
+        if (on_same_task) {
+          continue;
+        }
         // Convert a vector of devices to a string.
         // Using absl::StrJoin did not work in Android builds.
         string devices = "[";
@@ -523,6 +525,7 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
         devices.append("]");
 
         return errors::InvalidArgument(
+            *src_device,
            "When FunctionLibraryRuntime::Options.output_devices are "
            "not specified for a multi-device function, the device "
            "specification on the output node must match exactly one "
@@ -968,6 +971,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
       Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
       done(s);
     } else {
+      opts.ret_indices = comp_data->ret_indices;
       // Initialize remote function asynchronously.
       InstantiateRemote(unique_name, attrs, opts, component_handle, done);
     }
@@ -988,9 +992,9 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
 }
 
 Status ProcessFunctionLibraryRuntime::GetOutputDevices(
-    FunctionLibraryRuntime::Handle handle,
-    std::vector<Device*>* output_devices) const {
-  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
+    FunctionLibraryRuntime::Handle handle, std::vector<Device*>* output_devices,
+    const bool eager_lazy_copy) const {
+  MultiDeviceFunctionData* data = IsMultiDevice(handle);
   if (data == nullptr) {
     return errors::InvalidArgument(
         "Failed for find multi-device function handle ", handle);
@@ -1008,6 +1012,19 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
     Device* target_device = nullptr;
     Device* host = nullptr;
     if (target_flr == nullptr) {
+      if (!eager_lazy_copy) {
+        return errors::Unimplemented(
+            "Currently, outputting tensors on remote devices is not supported."
+            "The ",
+            comp_data.ret_indices[0],
+            "-th return value of the function outputs to target_device: ",
+            target,
+            " Please copy the tensor to local device explicitly using "
+            "tf.identity and return the new Tensor instead.");
+      }
+      if (!data->has_remote_outputs) {
+        data->has_remote_outputs = true;
+      }
       target_device = device_set()->FindDeviceByName(target);
       string remote_host;
       TF_RETURN_IF_ERROR(
@@ -1607,7 +1624,12 @@ void ProcessFunctionLibraryRuntime::Run(
     FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
     std::vector<FunctionRet>* rets,
     FunctionLibraryRuntime::DoneCallback done) const {
-  if (!args.HasRemoteOrPackedInputs()) {
+  bool has_remote_outputs = false;
+  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
+  if (data != nullptr) {
+    has_remote_outputs = data->has_remote_outputs;
+  }
+  if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) {
     const std::vector<Tensor> local_inputs = args.GetLocalTensors();
     std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
     return Run(
@@ -151,7 +151,8 @@ class ProcessFunctionLibraryRuntime {
   // is set to the device backing the resource.
   // REQUIRES: `handle` identifies a multi-device function.
   Status GetOutputDevices(FunctionLibraryRuntime::Handle handle,
-                          std::vector<Device*>* output_devices) const;
+                          std::vector<Device*>* output_devices,
+                          const bool eager_lazy_copy) const;
 
   // Returns true if function with handle `handle` was instantiated on device
   // `device_name`. Returns false for multi-device functions.
@@ -271,7 +272,8 @@ class ProcessFunctionLibraryRuntime {
           lib_def_(std::move(lib_def)),
           num_outputs_(num_outputs),
           ret_types_(std::move(ret_types)),
-          is_cross_process_(false) {}
+          is_cross_process_(false),
+          has_remote_outputs(false) {}
 
     const string function_name_;
     const string function_key_;
@@ -285,6 +287,8 @@ class ProcessFunctionLibraryRuntime {
 
     // Indicates whether this function needs to execute cross process.
     bool is_cross_process_;
+    // Indicates whether this function has remote outputs.
+    bool has_remote_outputs;
 
     // Maps the device name to the information about the component function
     // be run on this device.
@@ -105,6 +105,7 @@ cc_library(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:worker_proto_cc",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/distributed_runtime/worker_interface.h"
 #include "tensorflow/core/framework/function.h"
@@ -44,6 +44,7 @@ cc_library(
         "//tensorflow/core/common_runtime/eager:tensor_handle",
         "//tensorflow/core/distributed_runtime:call_options",
         "//tensorflow/core/distributed_runtime:worker_session",
+        "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
         "@com_google_absl//absl/types:variant",
     ],
@@ -96,14 +96,16 @@ void EagerClusterFunctionLibraryRuntime::Instantiate(
           .ToProto();
   StripDefaultAttributesInRegisterFunctionOp(register_function);
 
+  const absl::optional<std::vector<int>>& ret_indices = options.ret_indices;
   eager_client->EnqueueAsync(
       /*call_opts=*/nullptr, request.get(), response.get(),
       [this, request, response, handle, released_op = released_op.release(),
-       target, eager_client = eager_client.get(), done](const Status& s) {
+       target, ret_indices, eager_client = eager_client.get(),
+       done](const Status& s) {
         {
           mutex_lock l(mu_);
           *handle = function_data_.size();
-          function_data_.emplace_back(target, eager_client,
+          function_data_.emplace_back(target, ret_indices, eager_client,
                                       absl::WrapUnique(released_op));
         }
         done(s);
@@ -168,6 +170,12 @@ void EagerClusterFunctionLibraryRuntime::Run(
   request->set_context_id(context_id_);
   eager::Operation* remote_op = request->mutable_operation();
 
+  if (function_data->ret_indices.has_value()) {
+    for (const int ret_index : function_data->ret_indices.value()) {
+      request->add_output_num(ret_index);
+    }
+  }
+
   for (const auto& arg : args) {
     if (arg.index() == 0) {
       absl::get<Tensor>(arg).AsProtoTensorContent(
@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
@@ -84,12 +85,15 @@ class EagerClusterFunctionLibraryRuntime
 
   struct FunctionData {
     const string target;
+    const absl::optional<std::vector<int>> ret_indices;
     core::RefCountPtr<EagerClient> eager_client;
     std::unique_ptr<EagerOperation> op;
 
-    FunctionData(const string& target, EagerClient* eager_client,
-                 std::unique_ptr<EagerOperation> op)
+    FunctionData(const string& target,
+                 const absl::optional<std::vector<int>>& ret_indices,
+                 EagerClient* eager_client, std::unique_ptr<EagerOperation> op)
         : target(target),
+          ret_indices(ret_indices),
           eager_client(core::RefCountPtr<EagerClient>(eager_client)),
           op(std::move(op)) {
       eager_client->Ref();
@@ -171,7 +171,8 @@ Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
 
 Status AddOpRetvalsToResponse(
     EagerContext* eager_context, int op_id, int num_retvals,
-    TensorHandle** retvals, std::function<TensorProto*()> add_tensor_proto_fn,
+    const std::vector<int32>& output_nums, TensorHandle** retvals,
+    std::function<TensorProto*()> add_tensor_proto_fn,
     std::function<TensorShapeProto*()> add_shape_proto_fn,
     std::function<string*()> add_device_fn = nullptr) {
   if (op_id == kInvalidRemoteOpId) {
@@ -195,7 +196,9 @@ Status AddOpRetvalsToResponse(
       if (is_remote) {
         retvals[i]->Unref();
       } else {
-        eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id, i);
+        const int output_num = output_nums.empty() ? i : output_nums.at(i);
+        eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id,
+                                                       output_num);
       }
     }
   }
@@ -474,6 +477,10 @@ void EagerServiceImpl::RunComponentFunction(
   auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
   VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
           << operation.id();
+  std::vector<int32> output_nums;
+  for (const int32 output_num : request->output_num()) {
+    output_nums.push_back(output_num);
+  }
 
   auto cm = std::make_shared<CancellationManager>();
   op->SetCancellationManager(cm.get());
@@ -482,8 +489,8 @@ void EagerServiceImpl::RunComponentFunction(
   context->Ref();
   EagerLocalExecuteAsync(
       op, retvals->data(), num_retvals,
-      [op, op_id = operation.id(), num_retvals, retvals, cm, call_opts,
-       response, eager_context, context,
+      [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
+       call_opts, response, eager_context, context,
        done = std::move(done)](const Status& status) {
         call_opts->ClearCancelCallback();
         auto wrapped_done = [&](const Status& status) {
@@ -500,7 +507,7 @@ void EagerServiceImpl::RunComponentFunction(
         // The output device of a component function is the component device
         // which is known on the default device of it's parent function.
         wrapped_done(AddOpRetvalsToResponse(
-            eager_context, op_id, *num_retvals, retvals->data(),
+            eager_context, op_id, *num_retvals, output_nums, retvals->data(),
            [response] { return response->add_tensor(); },
            [response] { return response->add_shape(); }));
       });
@@ -539,8 +546,8 @@ Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
   }
 
   return AddOpRetvalsToResponse(
-      eager_context, operation.id(), num_retvals, retvals.data(),
-      [queue_response] { return queue_response->add_tensor(); },
+      eager_context, operation.id(), num_retvals, /*output_nums=*/{},
+      retvals.data(), [queue_response] { return queue_response->add_tensor(); },
       [queue_response] { return queue_response->add_shape(); },
       std::move(add_device_fn));
 }
@@ -224,10 +224,11 @@ void AddOperationToRunComponentFunctionRequest(
     const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
         inputs,
     const std::unordered_map<string, AttrValue>& attrs, const string& device,
-    RunComponentFunctionRequest* request) {
+    const int output_num, RunComponentFunctionRequest* request) {
   auto* operation = request->mutable_operation();
   operation->set_is_function(true);
   operation->set_is_component_function(true);
+  request->add_output_num(output_num);
   BuildOperation(operation, id, name, inputs, attrs, device);
 }
 
@@ -610,10 +611,12 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
     RunComponentFunctionRequest run_comp_func_request;
     run_comp_func_request.set_context_id(context_id);
     RunComponentFunctionResponse run_comp_func_response;
+    const int output_num = 5;
     AddOperationToRunComponentFunctionRequest(
         2, function_name, {std::make_pair(1, 0)},
         std::unordered_map<string, AttrValue>(),
-        "/job:localhost/replica:0/task:0/device:CPU:0", &run_comp_func_request);
+        "/job:localhost/replica:0/task:0/device:CPU:0", output_num,
+        &run_comp_func_request);
 
     CallOptions call_opts;
     Notification n;
@@ -636,7 +639,8 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
     const tensorflow::Tensor* t = nullptr;
     tensorflow::TensorHandle* tensor_handle;
     TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
-        context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
+        context_id, RemoteTensorHandleInternal(2, output_num),
+        &tensor_handle));
     TF_ASSERT_OK(tensor_handle->Tensor(&t));
 
     auto actual = t->flat<float>();
@@ -612,6 +612,9 @@ class FunctionLibraryRuntime {
     // infer correct device.
     std::vector<string> output_devices;
 
+    // If set, it indicates the original output indices of a component function.
+    absl::optional<std::vector<int>> ret_indices = absl::nullopt;
+
     // Maps from a CompositeDevice name to a list of underlying physical
     // devices.
     absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
@@ -180,6 +180,9 @@ message RunComponentFunctionRequest {
   fixed64 context_id = 1;
 
   Operation operation = 2;
+
+  // The output indices of its parent function.
+  repeated int32 output_num = 3;
 }
 
 message RunComponentFunctionResponse {
@@ -92,7 +92,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
 
     self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])
 
-  @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionRemoteOutput(self):
     with ops.device('/job:worker/replica:0/task:0/cpu:0'):
       variable_b = variables.Variable(1)
@@ -101,10 +100,15 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
     @def_function.function
     def remote_output(i):
       with ops.device('/job:worker/replica:0/task:0/cpu:0'):
         c = variable_b + 1
-        return c, i + variable_b
+      return i + variable_b, c
 
-    self.assertAllEqual(
-        remote_output(constant_op.constant([1]))[0].numpy(), 2)
+    rets = remote_output(constant_op.constant([1]))
+    self.assertEqual(rets[0].backing_device,
+                     '/job:localhost/replica:0/task:0/device:CPU:0')
+    self.assertEqual(rets[1].backing_device,
+                     '/job:worker/replica:0/task:0/device:CPU:0')
+    self.assertAllEqual(rets[0].numpy(), [2])
+    self.assertAllEqual(rets[1].numpy(), 2)
 
   def testMultiDeviceFunctionAmbiguousDevice(self):
@@ -482,6 +486,25 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
     with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
       self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
+  def testMultiDeviceFunctionRemoteOutput(self):
+    with ops.device('/job:worker/replica:0/task:1/cpu:0'):
+      variable_b = variables.Variable(1)
+
+    @def_function.function
+    def remote_output(i):
+      with ops.device('/job:worker/replica:0/task:1/cpu:0'):
+        c = variable_b + 1
+      return i + variable_b, c
+
+    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
+      rets = remote_output(constant_op.constant([1]))
+    self.assertEqual(rets[0].backing_device,
+                     '/job:worker/replica:0/task:0/device:CPU:0')
+    self.assertEqual(rets[1].backing_device,
+                     '/job:worker/replica:0/task:1/device:CPU:0')
+    self.assertAllEqual(rets[0].numpy(), [2])
+    self.assertAllEqual(rets[1].numpy(), 2)
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceWhileLoopOnRemoteDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):