Use the original output indices when adding a component function output to RemoteMgr.

PiperOrigin-RevId: 327507408
Change-Id: Ie33d8467aec3901340ac8edd8892f28811b92c2a
Yujing Zhang authored on 2020-08-19 14:30:25 -07:00; committed by TensorFlower Gardener
parent f06308bfc4
commit 699178a5d7
15 changed files with 123 additions and 39 deletions
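Context for the change, as read from the diff below: a multi-device function's return values can be produced by different component functions, each placed on its own device (possibly a remote worker). A component function registers its outputs with that worker's RemoteMgr, and before this change output i of a component function was always stored under key (op_id, i). The caller, however, looks tensors up by the indices those outputs have in the parent function, so this commit threads the original indices through instantiation (FunctionLibraryRuntime::InstantiateOptions::ret_indices), the wire protocol (RunComponentFunctionRequest.output_num), and the server (AddOpRetvalsToResponse), ending at RemoteMgr()->AddOperationOutput. Below is a minimal, standalone sketch of just the index mapping; it is not TensorFlow code, and the op id and indices are made up for illustration.

  #include <cstdio>
  #include <vector>

  int main() {
    // Hypothetical component function: the parent function returns three
    // tensors, and this component produces the ones with original indices
    // 0 and 2 (its ret_indices, sent as RunComponentFunctionRequest.output_num).
    const std::vector<int> output_nums = {0, 2};
    const int op_id = 42;  // made-up remote op id

    for (int i = 0; i < static_cast<int>(output_nums.size()); ++i) {
      // Before this change the key was (op_id, i); with it, the key is
      // (op_id, output_nums[i]), matching the index the caller uses to
      // look the tensor up on the producing worker.
      const int output_num = output_nums.empty() ? i : output_nums.at(i);
      std::printf("AddOperationOutput(retvals[%d], op_id=%d, output_num=%d)\n",
                  i, op_id, output_num);
    }
    return 0;
  }

With output_nums = {0, 2}, the two local outputs land under keys (42, 0) and (42, 2); this mirrors what the updated eager_service_impl_test checks by passing output_num into the request and then looking up RemoteTensorHandleInternal(op_id, output_num).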

View File

@@ -551,8 +551,9 @@ Status GetOrCreateKernelAndDevice(
ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
}
TF_RETURN_IF_ERROR(
kernel->Init({ctx.LogDevicePlacement()}, ndef, graph_collector));
TF_RETURN_IF_ERROR(kernel->Init(
{ctx.LogDevicePlacement(), ctx.LazyCopyFunctionRemoteInputs()}, ndef,
graph_collector));
if (op->is_function()) {
ctx.AddKernelToCache(cache_key, kernel.get());

View File

@@ -223,7 +223,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
Status KernelAndDeviceFunc::Init(const Context& ctx, const NodeDef& ndef,
GraphCollector* graph_collector) {
TF_RETURN_IF_ERROR(InstantiateFunc(ctx, ndef, graph_collector));
return pflr_->GetOutputDevices(handle_, &output_devices_);
return pflr_->GetOutputDevices(handle_, &output_devices_,
ctx.eager_lazy_copy);
}
namespace {

View File

@@ -97,6 +97,7 @@ class KernelAndDevice : public core::RefCounted {
public:
struct Context {
bool log_device_placement = false;
bool eager_lazy_copy = false;
};
// Populates this with a kernel appropriate for 'ndef'.

View File

@@ -466,18 +466,6 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
<< " src_device: " << *src_device
<< " colo group: " << colocation_group;
}
// If colocation_group is not set and output producing node is assigned
// to a remote device, colocate the retval node with its input node.
// TODO(yujingzhang): Remove this when we support outputting tensors on
// remote devices.
const bool remote_src_device =
!src_device->empty() && GetFLR(*src_device) == nullptr;
if (colocation_group.empty() && remote_src_device) {
colocation_group =
absl::StrCat(kColocationGroupPrefix, it->src()->name());
VLOG(3) << "Considering src: " << src_node->name()
<< " colo group: " << colocation_group;
}
// If resource is produced by a function call node, we can't trust
// source node device assignment, because multi-device functions can
@@ -510,6 +498,20 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
"Unable to find any devices for spec ", *src_device);
}
} else if (matching_devices.size() != 1) {
bool on_same_task = true;
for (int i = 1; i < matching_devices.size(); ++i) {
if (!DeviceNameUtils::IsSameAddressSpace(
matching_devices.at(0)->parsed_name(),
matching_devices.at(i)->parsed_name())) {
on_same_task = false;
break;
}
}
// If the src node of an output is assigned to an address space (e.g.
// py_func), rely on placer to assign a device to the output.
if (on_same_task) {
continue;
}
// Convert a vector of devices to a string.
// Using absl::StrJoin did not work in Android builds.
string devices = "[";
@@ -523,6 +525,7 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
devices.append("]");
return errors::InvalidArgument(
*src_device,
"When FunctionLibraryRuntime::Options.output_devices are "
"not specified for a multi-device function, the device "
"specification on the output node must match exactly one "
@@ -968,6 +971,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
done(s);
} else {
opts.ret_indices = comp_data->ret_indices;
// Initialize remote function asynchronously.
InstantiateRemote(unique_name, attrs, opts, component_handle, done);
}
@@ -988,9 +992,9 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
}
Status ProcessFunctionLibraryRuntime::GetOutputDevices(
FunctionLibraryRuntime::Handle handle,
std::vector<Device*>* output_devices) const {
const MultiDeviceFunctionData* data = IsMultiDevice(handle);
FunctionLibraryRuntime::Handle handle, std::vector<Device*>* output_devices,
const bool eager_lazy_copy) const {
MultiDeviceFunctionData* data = IsMultiDevice(handle);
if (data == nullptr) {
return errors::InvalidArgument(
"Failed to find multi-device function handle ", handle);
@@ -1008,6 +1012,19 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
Device* target_device = nullptr;
Device* host = nullptr;
if (target_flr == nullptr) {
if (!eager_lazy_copy) {
return errors::Unimplemented(
"Currently, outputting tensors on remote devices is not supported. "
"The ",
comp_data.ret_indices[0],
"-th return value of the function outputs to target_device: ",
target,
". Please copy the tensor to a local device explicitly using "
"tf.identity and return the new Tensor instead.");
}
if (!data->has_remote_outputs) {
data->has_remote_outputs = true;
}
target_device = device_set()->FindDeviceByName(target);
string remote_host;
TF_RETURN_IF_ERROR(
@@ -1607,7 +1624,12 @@ void ProcessFunctionLibraryRuntime::Run(
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
if (!args.HasRemoteOrPackedInputs()) {
bool has_remote_outputs = false;
const MultiDeviceFunctionData* data = IsMultiDevice(handle);
if (data != nullptr) {
has_remote_outputs = data->has_remote_outputs;
}
if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) {
const std::vector<Tensor> local_inputs = args.GetLocalTensors();
std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
return Run(

View File

@@ -151,7 +151,8 @@ class ProcessFunctionLibraryRuntime {
// is set to the device backing the resource.
// REQUIRES: `handle` identifies a multi-device function.
Status GetOutputDevices(FunctionLibraryRuntime::Handle handle,
std::vector<Device*>* output_devices) const;
std::vector<Device*>* output_devices,
const bool eager_lazy_copy) const;
// Returns true if function with handle `handle` was instantiated on device
// `device_name`. Returns false for multi-device functions.
@@ -271,7 +272,8 @@ class ProcessFunctionLibraryRuntime {
lib_def_(std::move(lib_def)),
num_outputs_(num_outputs),
ret_types_(std::move(ret_types)),
is_cross_process_(false) {}
is_cross_process_(false),
has_remote_outputs(false) {}
const string function_name_;
const string function_key_;
@@ -285,6 +287,8 @@ class ProcessFunctionLibraryRuntime {
// Indicates whether this function needs to execute cross process.
bool is_cross_process_;
// Indicates whether this function has remote outputs.
bool has_remote_outputs;
// Maps the device name to the information about the component function
// to be run on this device.

View File

@@ -105,6 +105,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
"@com_google_absl//absl/types:optional",
],
)

View File

@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
#include "absl/types/optional.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/function.h"

View File

@@ -44,6 +44,7 @@ cc_library(
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:worker_session",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
],

View File

@@ -96,14 +96,16 @@ void EagerClusterFunctionLibraryRuntime::Instantiate(
.ToProto();
StripDefaultAttributesInRegisterFunctionOp(register_function);
const absl::optional<std::vector<int>>& ret_indices = options.ret_indices;
eager_client->EnqueueAsync(
/*call_opts=*/nullptr, request.get(), response.get(),
[this, request, response, handle, released_op = released_op.release(),
target, eager_client = eager_client.get(), done](const Status& s) {
target, ret_indices, eager_client = eager_client.get(),
done](const Status& s) {
{
mutex_lock l(mu_);
*handle = function_data_.size();
function_data_.emplace_back(target, eager_client,
function_data_.emplace_back(target, ret_indices, eager_client,
absl::WrapUnique(released_op));
}
done(s);
@@ -168,6 +170,12 @@ void EagerClusterFunctionLibraryRuntime::Run(
request->set_context_id(context_id_);
eager::Operation* remote_op = request->mutable_operation();
if (function_data->ret_indices.has_value()) {
for (const int ret_index : function_data->ret_indices.value()) {
request->add_output_num(ret_index);
}
}
for (const auto& arg : args) {
if (arg.index() == 0) {
absl::get<Tensor>(arg).AsProtoTensorContent(

View File

@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
@@ -84,12 +85,15 @@ class EagerClusterFunctionLibraryRuntime
struct FunctionData {
const string target;
const absl::optional<std::vector<int>> ret_indices;
core::RefCountPtr<EagerClient> eager_client;
std::unique_ptr<EagerOperation> op;
FunctionData(const string& target, EagerClient* eager_client,
std::unique_ptr<EagerOperation> op)
FunctionData(const string& target,
const absl::optional<std::vector<int>>& ret_indices,
EagerClient* eager_client, std::unique_ptr<EagerOperation> op)
: target(target),
ret_indices(ret_indices),
eager_client(core::RefCountPtr<EagerClient>(eager_client)),
op(std::move(op)) {
eager_client->Ref();

View File

@@ -171,7 +171,8 @@ Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
Status AddOpRetvalsToResponse(
EagerContext* eager_context, int op_id, int num_retvals,
TensorHandle** retvals, std::function<TensorProto*()> add_tensor_proto_fn,
const std::vector<int32>& output_nums, TensorHandle** retvals,
std::function<TensorProto*()> add_tensor_proto_fn,
std::function<TensorShapeProto*()> add_shape_proto_fn,
std::function<string*()> add_device_fn = nullptr) {
if (op_id == kInvalidRemoteOpId) {
@@ -195,7 +196,9 @@ Status AddOpRetvalsToResponse(
if (is_remote) {
retvals[i]->Unref();
} else {
eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id, i);
const int output_num = output_nums.empty() ? i : output_nums.at(i);
eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id,
output_num);
}
}
}
@@ -474,6 +477,10 @@ void EagerServiceImpl::RunComponentFunction(
auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
<< operation.id();
std::vector<int32> output_nums;
for (const int32 output_num : request->output_num()) {
output_nums.push_back(output_num);
}
auto cm = std::make_shared<CancellationManager>();
op->SetCancellationManager(cm.get());
@@ -482,8 +489,8 @@ void EagerServiceImpl::RunComponentFunction(
context->Ref();
EagerLocalExecuteAsync(
op, retvals->data(), num_retvals,
[op, op_id = operation.id(), num_retvals, retvals, cm, call_opts,
response, eager_context, context,
[op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
call_opts, response, eager_context, context,
done = std::move(done)](const Status& status) {
call_opts->ClearCancelCallback();
auto wrapped_done = [&](const Status& status) {
@@ -500,7 +507,7 @@ void EagerServiceImpl::RunComponentFunction(
// The output device of a component function is the component device
// which is known on the default device of its parent function.
wrapped_done(AddOpRetvalsToResponse(
eager_context, op_id, *num_retvals, retvals->data(),
eager_context, op_id, *num_retvals, output_nums, retvals->data(),
[response] { return response->add_tensor(); },
[response] { return response->add_shape(); }));
});
@@ -539,8 +546,8 @@ Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
}
return AddOpRetvalsToResponse(
eager_context, operation.id(), num_retvals, retvals.data(),
[queue_response] { return queue_response->add_tensor(); },
eager_context, operation.id(), num_retvals, /*output_nums=*/{},
retvals.data(), [queue_response] { return queue_response->add_tensor(); },
[queue_response] { return queue_response->add_shape(); },
std::move(add_device_fn));
}

View File

@@ -224,10 +224,11 @@ void AddOperationToRunComponentFunctionRequest(
const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
inputs,
const std::unordered_map<string, AttrValue>& attrs, const string& device,
RunComponentFunctionRequest* request) {
const int output_num, RunComponentFunctionRequest* request) {
auto* operation = request->mutable_operation();
operation->set_is_function(true);
operation->set_is_component_function(true);
request->add_output_num(output_num);
BuildOperation(operation, id, name, inputs, attrs, device);
}
@@ -610,10 +611,12 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
RunComponentFunctionRequest run_comp_func_request;
run_comp_func_request.set_context_id(context_id);
RunComponentFunctionResponse run_comp_func_response;
const int output_num = 5;
AddOperationToRunComponentFunctionRequest(
2, function_name, {std::make_pair(1, 0)},
std::unordered_map<string, AttrValue>(),
"/job:localhost/replica:0/task:0/device:CPU:0", &run_comp_func_request);
"/job:localhost/replica:0/task:0/device:CPU:0", output_num,
&run_comp_func_request);
CallOptions call_opts;
Notification n;
@@ -636,7 +639,8 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* tensor_handle;
TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
context_id, RemoteTensorHandleInternal(2, output_num),
&tensor_handle));
TF_ASSERT_OK(tensor_handle->Tensor(&t));
auto actual = t->flat<float>();

View File

@@ -612,6 +612,9 @@ class FunctionLibraryRuntime {
// infer correct device.
std::vector<string> output_devices;
// If set, it indicates the original output indices of a component function.
absl::optional<std::vector<int>> ret_indices = absl::nullopt;
// Maps from a CompositeDevice name to a list of underlying physical
// devices.
absl::flat_hash_map<string, const std::vector<string>*> composite_devices;

View File

@@ -180,6 +180,9 @@ message RunComponentFunctionRequest {
fixed64 context_id = 1;
Operation operation = 2;
// The output indices of its parent function.
repeated int32 output_num = 3;
}
message RunComponentFunctionResponse {

View File

@@ -92,7 +92,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])
@test_util.eager_lazy_remote_copy_on_and_off
def testMultiDeviceFunctionRemoteOutput(self):
with ops.device('/job:worker/replica:0/task:0/cpu:0'):
variable_b = variables.Variable(1)
@@ -101,10 +100,15 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
def remote_output(i):
with ops.device('/job:worker/replica:0/task:0/cpu:0'):
c = variable_b + 1
return c, i + variable_b
return i + variable_b, c
self.assertAllEqual(
remote_output(constant_op.constant([1]))[0].numpy(), 2)
rets = remote_output(constant_op.constant([1]))
self.assertEqual(rets[0].backing_device,
'/job:localhost/replica:0/task:0/device:CPU:0')
self.assertEqual(rets[1].backing_device,
'/job:worker/replica:0/task:0/device:CPU:0')
self.assertAllEqual(rets[0].numpy(), [2])
self.assertAllEqual(rets[1].numpy(), 2)
def testMultiDeviceFunctionAmbiguousDevice(self):
@@ -482,6 +486,25 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
def testMultiDeviceFunctionRemoteOutput(self):
with ops.device('/job:worker/replica:0/task:1/cpu:0'):
variable_b = variables.Variable(1)
@def_function.function
def remote_output(i):
with ops.device('/job:worker/replica:0/task:1/cpu:0'):
c = variable_b + 1
return i + variable_b, c
with ops.device('/job:worker/replica:0/task:0/cpu:0'):
rets = remote_output(constant_op.constant([1]))
self.assertEqual(rets[0].backing_device,
'/job:worker/replica:0/task:0/device:CPU:0')
self.assertEqual(rets[1].backing_device,
'/job:worker/replica:0/task:1/device:CPU:0')
self.assertAllEqual(rets[0].numpy(), [2])
self.assertAllEqual(rets[1].numpy(), 2)
@test_util.eager_lazy_remote_copy_on_and_off
def testMultiDeviceWhileLoopOnRemoteDevice(self):
with ops.device('/job:worker/replica:0/task:1'):