Use the original output indices when adding a component function output to RemoteMgr.
PiperOrigin-RevId: 327507408
Change-Id: Ie33d8467aec3901340ac8edd8892f28811b92c2a
parent f06308bfc4
commit 699178a5d7
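
In short: a component function's return values are numbered locally (0, 1, ...) on the worker that executes them, while the parent multi-device function refers to them by their original output indices. This change plumbs those original indices through the new ret_indices instantiate option and the RunComponentFunctionRequest.output_num field, so the eager service registers each returned TensorHandle with RemoteMgr under its original index rather than its local one. A minimal compilable sketch of the mapping (the helper name is hypothetical; the real logic lives in AddOpRetvalsToResponse in the diff below):

#include <vector>

// Pick the index under which the i-th local return value is registered with
// RemoteMgr. output_nums carries the original indices from the request and is
// empty when the op is not a component function.
int OutputNumForRetval(const std::vector<int>& output_nums, int i) {
  return output_nums.empty() ? i : output_nums.at(i);
}
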
@@ -551,8 +551,9 @@ Status GetOrCreateKernelAndDevice(
         ctx.GetCollectiveExecutorHandle(), ctx.HostCPU()));
   }
 
-  TF_RETURN_IF_ERROR(
-      kernel->Init({ctx.LogDevicePlacement()}, ndef, graph_collector));
+  TF_RETURN_IF_ERROR(kernel->Init(
+      {ctx.LogDevicePlacement(), ctx.LazyCopyFunctionRemoteInputs()}, ndef,
+      graph_collector));
 
   if (op->is_function()) {
     ctx.AddKernelToCache(cache_key, kernel.get());
@@ -223,7 +223,8 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx,
 Status KernelAndDeviceFunc::Init(const Context& ctx, const NodeDef& ndef,
                                  GraphCollector* graph_collector) {
   TF_RETURN_IF_ERROR(InstantiateFunc(ctx, ndef, graph_collector));
-  return pflr_->GetOutputDevices(handle_, &output_devices_);
+  return pflr_->GetOutputDevices(handle_, &output_devices_,
+                                 ctx.eager_lazy_copy);
 }
 
 namespace {
@@ -97,6 +97,7 @@ class KernelAndDevice : public core::RefCounted {
  public:
   struct Context {
     bool log_device_placement = false;
+    bool eager_lazy_copy = false;
   };
 
   // Populates this with a kernel appropriate for 'ndef'.
@@ -466,18 +466,6 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
                 << " src_device: " << *src_device
                 << " colo group: " << colocation_group;
       }
-      // If colocation_group is not set and output producing node is assigned
-      // to a remote device, colocate the retval node with its input node.
-      // TODO(yujingzhang): Remove this when we support outputting tensors on
-      // remote devices.
-      const bool remote_src_device =
-          !src_device->empty() && GetFLR(*src_device) == nullptr;
-      if (colocation_group.empty() && remote_src_device) {
-        colocation_group =
-            absl::StrCat(kColocationGroupPrefix, it->src()->name());
-        VLOG(3) << "Considering src: " << src_node->name()
-                << " colo group: " << colocation_group;
-      }
 
       // If resource is produced by a function call node, we can't trust
       // source node device assignment, because multi-device functions can
@@ -510,6 +498,20 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
               "Unable to find any devices for spec ", *src_device);
         }
       } else if (matching_devices.size() != 1) {
+        bool on_same_task = true;
+        for (int i = 1; i < matching_devices.size(); ++i) {
+          if (!DeviceNameUtils::IsSameAddressSpace(
+                  matching_devices.at(0)->parsed_name(),
+                  matching_devices.at(i)->parsed_name())) {
+            on_same_task = false;
+            break;
+          }
+        }
+        // If the src node of an output is assigned to a address space (e.g.
+        // py_func), rely on placer to assign a device to the output.
+        if (on_same_task) {
+          continue;
+        }
         // Convert a vector of devices to a string.
         // Using absl::StrJoin did not work in Android builds.
         string devices = "[";
@@ -523,6 +525,7 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
         devices.append("]");
 
         return errors::InvalidArgument(
+            *src_device,
             "When FunctionLibraryRuntime::Options.output_devices are "
             "not specified for a multi-device function, the device "
             "specification on the output node must match exactly one "
@@ -968,6 +971,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
           Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
           done(s);
         } else {
+          opts.ret_indices = comp_data->ret_indices;
           // Initialize remote function asynchronously.
           InstantiateRemote(unique_name, attrs, opts, component_handle, done);
         }
@@ -988,9 +992,9 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
 }
 
 Status ProcessFunctionLibraryRuntime::GetOutputDevices(
-    FunctionLibraryRuntime::Handle handle,
-    std::vector<Device*>* output_devices) const {
-  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
+    FunctionLibraryRuntime::Handle handle, std::vector<Device*>* output_devices,
+    const bool eager_lazy_copy) const {
+  MultiDeviceFunctionData* data = IsMultiDevice(handle);
   if (data == nullptr) {
     return errors::InvalidArgument(
         "Failed for find multi-device function handle ", handle);
@@ -1008,6 +1012,19 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
     Device* target_device = nullptr;
     Device* host = nullptr;
     if (target_flr == nullptr) {
+      if (!eager_lazy_copy) {
+        return errors::Unimplemented(
+            "Currently, outputting tensors on remote devices is not supported."
+            "The ",
+            comp_data.ret_indices[0],
+            "-th return value of the function outputs to target_device: ",
+            target,
+            " Please copy the tensor to local device explicitly using "
+            "tf.identity and return the new Tensor instead.");
+      }
+      if (!data->has_remote_outputs) {
+        data->has_remote_outputs = true;
+      }
       target_device = device_set()->FindDeviceByName(target);
       string remote_host;
       TF_RETURN_IF_ERROR(
@@ -1607,7 +1624,12 @@ void ProcessFunctionLibraryRuntime::Run(
     FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
     std::vector<FunctionRet>* rets,
     FunctionLibraryRuntime::DoneCallback done) const {
-  if (!args.HasRemoteOrPackedInputs()) {
+  bool has_remote_outputs = false;
+  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
+  if (data != nullptr) {
+    has_remote_outputs = data->has_remote_outputs;
+  }
+  if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) {
     const std::vector<Tensor> local_inputs = args.GetLocalTensors();
     std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
     return Run(
@@ -151,7 +151,8 @@ class ProcessFunctionLibraryRuntime {
   // is set to the device backing the resource.
   // REQUIRES: `handle` identifies a multi-device function.
   Status GetOutputDevices(FunctionLibraryRuntime::Handle handle,
-                          std::vector<Device*>* output_devices) const;
+                          std::vector<Device*>* output_devices,
+                          const bool eager_lazy_copy) const;
 
   // Returns true if function with handle `handle` was instantiated on device
   // `device_name`. Returns false for multi-device functions.
@@ -271,7 +272,8 @@ class ProcessFunctionLibraryRuntime {
           lib_def_(std::move(lib_def)),
           num_outputs_(num_outputs),
           ret_types_(std::move(ret_types)),
-          is_cross_process_(false) {}
+          is_cross_process_(false),
+          has_remote_outputs(false) {}
 
     const string function_name_;
     const string function_key_;
@@ -285,6 +287,8 @@ class ProcessFunctionLibraryRuntime {
 
     // Indicates whether this function needs to execute cross process.
     bool is_cross_process_;
+    // Indicates whether this function has remote outputs.
+    bool has_remote_outputs;
 
     // Maps the device name to the information about the component function
     // be run on this device.
@@ -105,6 +105,7 @@ cc_library(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:worker_proto_cc",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/distributed_runtime/worker_interface.h"
 #include "tensorflow/core/framework/function.h"
@@ -44,6 +44,7 @@ cc_library(
         "//tensorflow/core/common_runtime/eager:tensor_handle",
         "//tensorflow/core/distributed_runtime:call_options",
         "//tensorflow/core/distributed_runtime:worker_session",
+        "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
         "@com_google_absl//absl/types:variant",
     ],
@@ -96,14 +96,16 @@ void EagerClusterFunctionLibraryRuntime::Instantiate(
                           .ToProto();
   StripDefaultAttributesInRegisterFunctionOp(register_function);
 
+  const absl::optional<std::vector<int>>& ret_indices = options.ret_indices;
   eager_client->EnqueueAsync(
       /*call_opts=*/nullptr, request.get(), response.get(),
       [this, request, response, handle, released_op = released_op.release(),
-       target, eager_client = eager_client.get(), done](const Status& s) {
+       target, ret_indices, eager_client = eager_client.get(),
+       done](const Status& s) {
         {
           mutex_lock l(mu_);
           *handle = function_data_.size();
-          function_data_.emplace_back(target, eager_client,
+          function_data_.emplace_back(target, ret_indices, eager_client,
                                       absl::WrapUnique(released_op));
         }
         done(s);
@@ -168,6 +170,12 @@ void EagerClusterFunctionLibraryRuntime::Run(
   request->set_context_id(context_id_);
   eager::Operation* remote_op = request->mutable_operation();
 
+  if (function_data->ret_indices.has_value()) {
+    for (const int ret_index : function_data->ret_indices.value()) {
+      request->add_output_num(ret_index);
+    }
+  }
+
   for (const auto& arg : args) {
     if (arg.index() == 0) {
       absl::get<Tensor>(arg).AsProtoTensorContent(
@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
@@ -84,12 +85,15 @@ class EagerClusterFunctionLibraryRuntime
 
   struct FunctionData {
     const string target;
+    const absl::optional<std::vector<int>> ret_indices;
     core::RefCountPtr<EagerClient> eager_client;
     std::unique_ptr<EagerOperation> op;
 
-    FunctionData(const string& target, EagerClient* eager_client,
-                 std::unique_ptr<EagerOperation> op)
+    FunctionData(const string& target,
+                 const absl::optional<std::vector<int>>& ret_indices,
+                 EagerClient* eager_client, std::unique_ptr<EagerOperation> op)
         : target(target),
+          ret_indices(ret_indices),
           eager_client(core::RefCountPtr<EagerClient>(eager_client)),
           op(std::move(op)) {
       eager_client->Ref();
@@ -171,7 +171,8 @@ Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
 
 Status AddOpRetvalsToResponse(
     EagerContext* eager_context, int op_id, int num_retvals,
-    TensorHandle** retvals, std::function<TensorProto*()> add_tensor_proto_fn,
+    const std::vector<int32>& output_nums, TensorHandle** retvals,
+    std::function<TensorProto*()> add_tensor_proto_fn,
     std::function<TensorShapeProto*()> add_shape_proto_fn,
     std::function<string*()> add_device_fn = nullptr) {
   if (op_id == kInvalidRemoteOpId) {
@@ -195,7 +196,9 @@ Status AddOpRetvalsToResponse(
     if (is_remote) {
       retvals[i]->Unref();
     } else {
-      eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id, i);
+      const int output_num = output_nums.empty() ? i : output_nums.at(i);
+      eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id,
+                                                     output_num);
     }
   }
 }
@@ -474,6 +477,10 @@ void EagerServiceImpl::RunComponentFunction(
   auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
   VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
           << operation.id();
+  std::vector<int32> output_nums;
+  for (const int32 output_num : request->output_num()) {
+    output_nums.push_back(output_num);
+  }
 
   auto cm = std::make_shared<CancellationManager>();
   op->SetCancellationManager(cm.get());
@@ -482,8 +489,8 @@ void EagerServiceImpl::RunComponentFunction(
   context->Ref();
   EagerLocalExecuteAsync(
       op, retvals->data(), num_retvals,
-      [op, op_id = operation.id(), num_retvals, retvals, cm, call_opts,
-       response, eager_context, context,
+      [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
+       call_opts, response, eager_context, context,
        done = std::move(done)](const Status& status) {
         call_opts->ClearCancelCallback();
         auto wrapped_done = [&](const Status& status) {
@@ -500,7 +507,7 @@ void EagerServiceImpl::RunComponentFunction(
           // The output device of a component function is the component device
           // which is known on the default device of it's parent function.
           wrapped_done(AddOpRetvalsToResponse(
-              eager_context, op_id, *num_retvals, retvals->data(),
+              eager_context, op_id, *num_retvals, output_nums, retvals->data(),
               [response] { return response->add_tensor(); },
              [response] { return response->add_shape(); }));
        });
@@ -539,8 +546,8 @@ Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
   }
 
   return AddOpRetvalsToResponse(
-      eager_context, operation.id(), num_retvals, retvals.data(),
-      [queue_response] { return queue_response->add_tensor(); },
+      eager_context, operation.id(), num_retvals, /*output_nums=*/{},
+      retvals.data(), [queue_response] { return queue_response->add_tensor(); },
       [queue_response] { return queue_response->add_shape(); },
       std::move(add_device_fn));
 }
@@ -224,10 +224,11 @@ void AddOperationToRunComponentFunctionRequest(
     const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
         inputs,
     const std::unordered_map<string, AttrValue>& attrs, const string& device,
-    RunComponentFunctionRequest* request) {
+    const int output_num, RunComponentFunctionRequest* request) {
   auto* operation = request->mutable_operation();
   operation->set_is_function(true);
   operation->set_is_component_function(true);
+  request->add_output_num(output_num);
   BuildOperation(operation, id, name, inputs, attrs, device);
 }
 
@@ -610,10 +611,12 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest {
     RunComponentFunctionRequest run_comp_func_request;
     run_comp_func_request.set_context_id(context_id);
     RunComponentFunctionResponse run_comp_func_response;
+    const int output_num = 5;
     AddOperationToRunComponentFunctionRequest(
         2, function_name, {std::make_pair(1, 0)},
         std::unordered_map<string, AttrValue>(),
-        "/job:localhost/replica:0/task:0/device:CPU:0", &run_comp_func_request);
+        "/job:localhost/replica:0/task:0/device:CPU:0", output_num,
+        &run_comp_func_request);
 
     CallOptions call_opts;
     Notification n;
|
|||||||
const tensorflow::Tensor* t = nullptr;
|
const tensorflow::Tensor* t = nullptr;
|
||||||
tensorflow::TensorHandle* tensor_handle;
|
tensorflow::TensorHandle* tensor_handle;
|
||||||
TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
|
TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
|
||||||
context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
|
context_id, RemoteTensorHandleInternal(2, output_num),
|
||||||
|
&tensor_handle));
|
||||||
TF_ASSERT_OK(tensor_handle->Tensor(&t));
|
TF_ASSERT_OK(tensor_handle->Tensor(&t));
|
||||||
|
|
||||||
auto actual = t->flat<float>();
|
auto actual = t->flat<float>();
|
||||||
|
@@ -612,6 +612,9 @@ class FunctionLibraryRuntime {
   // infer correct device.
   std::vector<string> output_devices;
 
+  // If set, it indicates the original output indices of a component function.
+  absl::optional<std::vector<int>> ret_indices = absl::nullopt;
+
   // Maps from a CompositeDevice name to a list of underlying physical
   // devices.
   absl::flat_hash_map<string, const std::vector<string>*> composite_devices;
@@ -180,6 +180,9 @@ message RunComponentFunctionRequest {
   fixed64 context_id = 1;
 
   Operation operation = 2;
+
+  // The output indices of its parent function.
+  repeated int32 output_num = 3;
 }
 
 message RunComponentFunctionResponse {
@@ -92,7 +92,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
 
     self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])
 
-  @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionRemoteOutput(self):
     with ops.device('/job:worker/replica:0/task:0/cpu:0'):
       variable_b = variables.Variable(1)
@@ -101,10 +100,15 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
     def remote_output(i):
       with ops.device('/job:worker/replica:0/task:0/cpu:0'):
         c = variable_b + 1
-      return c, i + variable_b
+      return i + variable_b, c
 
-    self.assertAllEqual(
-        remote_output(constant_op.constant([1]))[0].numpy(), 2)
+    rets = remote_output(constant_op.constant([1]))
+    self.assertEqual(rets[0].backing_device,
+                     '/job:localhost/replica:0/task:0/device:CPU:0')
+    self.assertEqual(rets[1].backing_device,
+                     '/job:worker/replica:0/task:0/device:CPU:0')
+    self.assertAllEqual(rets[0].numpy(), [2])
+    self.assertAllEqual(rets[1].numpy(), 2)
 
   def testMultiDeviceFunctionAmbiguousDevice(self):
 
@@ -482,6 +486,25 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
     with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
       self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
+  def testMultiDeviceFunctionRemoteOutput(self):
+    with ops.device('/job:worker/replica:0/task:1/cpu:0'):
+      variable_b = variables.Variable(1)
+
+    @def_function.function
+    def remote_output(i):
+      with ops.device('/job:worker/replica:0/task:1/cpu:0'):
+        c = variable_b + 1
+      return i + variable_b, c
+
+    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
+      rets = remote_output(constant_op.constant([1]))
+    self.assertEqual(rets[0].backing_device,
+                     '/job:worker/replica:0/task:0/device:CPU:0')
+    self.assertEqual(rets[1].backing_device,
+                     '/job:worker/replica:0/task:1/device:CPU:0')
+    self.assertAllEqual(rets[0].numpy(), [2])
+    self.assertAllEqual(rets[1].numpy(), 2)
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceWhileLoopOnRemoteDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):