Change the function output type to be either a Tensor (for a local output) or a TensorShape (for a remote output), in preparation for supporting function outputs placed on remote workers.

PiperOrigin-RevId: 324938354
Change-Id: I126822bd75bb284c917af7a72f2868601e798f09
This commit is contained in:
Yujing Zhang 2020-08-04 19:09:39 -07:00 committed by TensorFlower Gardener
parent e95a955af8
commit 6388aa43d7
10 changed files with 144 additions and 42 deletions

View File

@ -395,13 +395,25 @@ void KernelAndDeviceFunc::RunAsync(
},
profiler::ContextType::kTfExecutor, opts->step_id,
profiler::TraceMeLevel::kInfo);
pflr_->Run(*opts, handle_, inputs, outputs,
[opts, rendezvous, local_cm, step_container, this,
done = std::move(done)](const Status& s) {
std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
pflr_->Run(*opts, handle_, inputs, function_rets,
[opts, outputs, function_rets, rendezvous, local_cm,
step_container, this, done = std::move(done)](const Status& s) {
rendezvous->Unref();
if (step_container == nullptr) {
this->step_container_.CleanUp();
}
if (s.ok()) {
// TODO(b/162618595): Change the type of `outputs` to
// support TensorShapes for remote outputs and remove the
// FunctionRet to Tensor conversion here.
for (const auto& ret : *function_rets) {
if (ret.index() == 0) {
outputs->push_back(absl::get<Tensor>(ret));
}
}
}
delete function_rets;
done(s);
});
}

View File

@ -398,6 +398,21 @@ std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
return tensors;
}
// Returns a done callback that, on success, appends every Tensor
// accumulated in `tensors` to `rets` (as local FunctionRet values) and then
// forwards the status to `done`. Takes ownership of `tensors`.
FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback(
    std::vector<FunctionRet>* rets, std::vector<Tensor>* tensors,
    FunctionLibraryRuntime::DoneCallback done) {
  // Own `tensors` via shared_ptr so it is released even if the returned
  // callback is destroyed without ever being invoked (a raw `delete` inside
  // the body would leak in that case). A shared_ptr capture also keeps the
  // lambda copyable, which std::function (DoneCallback) requires —
  // a unique_ptr capture would not compile.
  std::shared_ptr<std::vector<Tensor>> owned_tensors(tensors);
  return [rets, owned_tensors, done = std::move(done)](const Status& s) {
    if (s.ok()) {
      rets->reserve(rets->size() + owned_tensors->size());
      for (const auto& t : *owned_tensors) {
        // Implicitly wraps the local Tensor in the FunctionRet variant.
        rets->push_back(t);
      }
    }
    done(s);
  };
}
} // anonymous namespace
Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
@ -1021,7 +1036,7 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
void ProcessFunctionLibraryRuntime::RunMultiDevice(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, std::vector<Tensor>* rets,
FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets,
std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
FunctionLibraryRuntime::DoneCallback done,
std::function<Status(const ComponentFunctionData& comp_data,
@ -1097,7 +1112,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
cm->StartCancel();
continue;
}
std::vector<Tensor>* comp_rets = new std::vector<Tensor>;
std::vector<FunctionRet>* comp_rets = new std::vector<FunctionRet>;
rets->resize(data->num_outputs_);
auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done,
@ -1136,8 +1151,11 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice(
<< " with handle " << handle;
VLOG(4) << " with " << opts_copy.DebugString();
flr->Run(opts_copy, handle, GetLocalArgs(comp_args.args), comp_rets,
std::move(component_fn_callback));
std::vector<Tensor>* comp_tensor_rets = new std::vector<Tensor>;
flr->Run(
opts_copy, handle, GetLocalArgs(comp_args.args), comp_tensor_rets,
TensorsToFunctionRetsDoneCallback(comp_rets, comp_tensor_rets,
std::move(component_fn_callback)));
} else {
opts_copy.remote_execution = true;
@ -1362,6 +1380,23 @@ void ProcessFunctionLibraryRuntime::Run(
auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
new_opts.step_id, created_rendezvous);
std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
done = [rets, function_rets, done = std::move(done)](const Status& s) {
Status status = s;
if (status.ok()) {
for (const auto& ret : *function_rets) {
if (ret.index() == 0) {
rets->push_back(absl::get<Tensor>(ret));
} else {
status.Update(errors::Internal(
"Expect a Tensor as a function output but got a TensorShape."));
break;
}
}
}
delete function_rets;
done(status);
};
bool multi_device;
{
tf_shared_lock l(mu_);
@ -1392,21 +1427,21 @@ void ProcessFunctionLibraryRuntime::Run(
}
return Status::OK();
};
return RunMultiDevice(new_opts, handle, rets, cleanup_items,
return RunMultiDevice(new_opts, handle, function_rets, cleanup_items,
std::move(done), std::move(get_component_args));
}
std::vector<FunctionArg> local_args;
for (const auto& tensor : args) {
local_args.push_back(tensor);
}
RunInternal(new_opts, handle, local_args, rets, cleanup_items,
RunInternal(new_opts, handle, local_args, function_rets, cleanup_items,
std::move(done));
}
void ProcessFunctionLibraryRuntime::RunInternal(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
std::vector<Tensor>* rets,
std::vector<FunctionRet>* rets,
std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
FunctionLibraryRuntime::DoneCallback done) const {
FunctionLibraryRuntime* flr = nullptr;
@ -1475,10 +1510,13 @@ void ProcessFunctionLibraryRuntime::RunInternal(
int64 num_returns = remote_rets->size();
delete remote_rets;
// Now receive the return values from the target.
std::vector<Tensor>* recv_tensors = new std::vector<Tensor>;
ReceiveTensorsAsync(target_device, source_device, "ret_",
target_incarnation, num_returns,
device_context, rets_alloc_attrs, rendezvous,
rets, std::move(done));
recv_tensors,
TensorsToFunctionRetsDoneCallback(
rets, recv_tensors, std::move(done)));
});
return;
}
@ -1570,11 +1608,14 @@ Status ProcessFunctionLibraryRuntime::RunSync(
void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
std::vector<Tensor>* rets,
std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
if (!args.HasRemoteOrPackedInputs()) {
const std::vector<Tensor> local_inputs = args.GetLocalTensors();
return Run(opts, handle, local_inputs, rets, std::move(done));
std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
return Run(
opts, handle, local_inputs, tensor_rets,
TensorsToFunctionRetsDoneCallback(rets, tensor_rets, std::move(done)));
}
FunctionLibraryRuntime::Options new_opts = opts;

View File

@ -191,7 +191,7 @@ class ProcessFunctionLibraryRuntime {
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
const FunctionArgsInterface& args, std::vector<Tensor>* rets,
const FunctionArgsInterface& args, std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) const;
Status RunSync(const FunctionLibraryRuntime::Options& opts,
@ -304,7 +304,7 @@ class ProcessFunctionLibraryRuntime {
void RunMultiDevice(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, std::vector<Tensor>* rets,
FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets,
std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
FunctionLibraryRuntime::DoneCallback done,
std::function<Status(const ComponentFunctionData& comp_data,
@ -388,7 +388,8 @@ class ProcessFunctionLibraryRuntime {
void RunInternal(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
gtl::ArraySlice<FunctionArg> args,
std::vector<FunctionRet>* rets,
std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
FunctionLibraryRuntime::DoneCallback done) const;

View File

@ -72,7 +72,7 @@ class TestClusterFLR : public DistributedFunctionLibraryRuntime {
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) override {}
void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
@ -209,12 +209,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}
template <typename T>
template <typename T, typename K>
Status RunWithRuntime(
const string& name, FunctionLibraryRuntime::Options opts,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
const T& args, std::vector<Tensor*> rets,
const T& args, std::vector<K*> rets,
ProcessFunctionLibraryRuntime* pflr) {
FunctionLibraryRuntime::Handle handle;
Status status = pflr->Instantiate(name, attrs, instantiate_opts, &handle);
@ -234,7 +234,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
Notification done;
opts.runner = &runner;
std::vector<Tensor> out;
std::vector<K> out;
pflr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
status = s;
done.Notify();
@ -273,7 +273,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
const std::vector<Tensor>& args, std::vector<Tensor*> rets,
ProcessFunctionLibraryRuntime* pflr = nullptr) {
return RunWithRuntime<std::vector<Tensor>>(
return RunWithRuntime<std::vector<Tensor>, Tensor>(
name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
}
@ -281,9 +281,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
const string& name, FunctionLibraryRuntime::Options opts,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
const FunctionArgsInterface& args, std::vector<Tensor*> rets,
const FunctionArgsInterface& args, std::vector<FunctionRet*> rets,
ProcessFunctionLibraryRuntime* pflr = nullptr) {
return RunWithRuntime<FunctionArgsInterface>(
return RunWithRuntime<FunctionArgsInterface, FunctionRet>(
name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
}
@ -879,10 +879,12 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
handles.push_back(TensorValue(&resource_handle0));
handles.push_back(TensorValue(&resource_handle1));
TestFunctionPackedArgs args(0, std::move(handles));
Tensor ret;
FunctionRet ret;
TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts,
{{"T", DT_FLOAT}}, inst_opts, args, {&ret}));
test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
EXPECT_EQ(ret.index(), 0);
test::ExpectTensorEqual<float>(absl::get<Tensor>(ret),
test::AsTensor<float>({40, 60}));
}
// Packed Tensor
@ -1226,9 +1228,10 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
const auto x = test::AsTensor<int64>({17});
Tensor y;
TF_CHECK_OK(RunWithRuntime<std::vector<Tensor>>(
Status s = RunWithRuntime<std::vector<Tensor>, Tensor>(
"SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y},
cloned_proc_flr.get()));
cloned_proc_flr.get());
TF_CHECK_OK(s);
SessionMetadata read_metadata;
ASSERT_TRUE(protobuf::TextFormat::ParseFromString(y.scalar<tstring>()(),
&read_metadata));

View File

@ -333,7 +333,7 @@ void ClusterFunctionLibraryRuntime::Run(
void ClusterFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) {
std::vector<Tensor> tensors;
for (const auto& arg : args) {
@ -346,7 +346,17 @@ void ClusterFunctionLibraryRuntime::Run(
return;
}
}
return Run(opts, handle, tensors, rets, std::move(done));
std::vector<Tensor>* ret_tensors = new std::vector<Tensor>;
return Run(opts, handle, tensors, ret_tensors,
[rets, ret_tensors, done = std::move(done)](const Status& s) {
if (s.ok()) {
for (const auto& t : *ret_tensors) {
rets->push_back(t);
}
}
delete ret_tensors;
done(s);
});
}
void ClusterFunctionLibraryRuntime::CleanUp(

View File

@ -49,7 +49,7 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) override;
void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,

View File

@ -118,13 +118,31 @@ void EagerClusterFunctionLibraryRuntime::Run(
for (const auto& tensor : args) {
function_args.push_back(tensor);
}
Run(opts, handle, function_args, rets, std::move(done));
std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
Run(opts, handle, function_args, function_rets,
[rets, function_rets, done = std::move(done)](const Status& s) {
Status status = s;
if (status.ok()) {
for (const auto& t : *function_rets) {
if (t.index() == 0) {
rets->push_back(absl::get<Tensor>(t));
} else {
status.Update(
errors::Internal("Expect a Tensor as a remote function "
"output but got a TensorShape."));
break;
}
}
}
delete function_rets;
done(status);
});
}
void EagerClusterFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) {
FunctionData* function_data = nullptr;
{
@ -204,6 +222,14 @@ void EagerClusterFunctionLibraryRuntime::Run(
done(s);
return;
}
if (!response->shape().empty() && !response->tensor().empty()) {
done(errors::Internal(
"Both shape and tensor are specified in the same response"));
return;
}
for (const auto& shape : response->shape()) {
rets->push_back(shape);
}
for (const auto& tensor_proto : response->tensor()) {
Tensor t;
if (t.FromProto(tensor_proto)) {

View File

@ -64,11 +64,12 @@ class EagerClusterFunctionLibraryRuntime
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) override;
// The component function inputs `args` can be RemoteTensorHandles, which will
// be lazily resolved remotely where the inputs are actually consumed.
// The component function inputs `args` and outputs `rets` may refer to remote
// tensors on a remote device, which will be lazily resolved remotely where
// the inputs/outputs are actually consumed.
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) override;
void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,

View File

@ -830,7 +830,7 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
input.set_op_device(local_device_);
input.set_device(local_device_);
std::vector<RemoteTensorHandle> inputs = {input};
std::vector<Tensor> outputs;
std::vector<FunctionRet> outputs;
gtl::InlinedVector<TensorValue, 4> tensor_args = {TensorValue()};
TestExecuteNodeArgs args(
std::move(tensor_args),
@ -845,6 +845,10 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
});
done.WaitForNotification();
TF_ASSERT_OK(status);
EXPECT_EQ(outputs.size(), 1);
EXPECT_EQ(outputs.at(0).index(), 1);
const TensorShape& shape = absl::get<TensorShape>(outputs.at(0));
EXPECT_EQ(shape, TensorShape({2, 2}));
CheckOutputsAndClose(op_id);
}

View File

@ -901,6 +901,9 @@ typedef
FunctionArg;
#endif
// A function output: either a Tensor produced locally, or only the
// TensorShape of a tensor that remains on a remote worker.
using FunctionRet = absl::variant<Tensor, TensorShape>;
// Used to instantiate and run functions in a distributed system.
class DistributedFunctionLibraryRuntime {
public:
@ -929,14 +932,15 @@ class DistributedFunctionLibraryRuntime {
// Run an instantiated remote function (specified by `handle`) with a list of
// input Tensors or RemoteTensorHandles as `args` and get its output Tensors
// in `rets`. When using RemoteTensorHandles as function inputs, the
// corresponding tensor data will be resolved on the remote worker, so it is
// not required to be locally available on the caller side. Using
// RemoteTensorHandle inputs is not supported in TensorFlow v1 runtime.
// TODO(yujingzhang): Support outputting tensors on remote devices.
// or TensorShapes in `rets`. When using RemoteTensorHandles as function
// inputs or TensorShapes as outputs, the corresponding tensor data will be
// resolved on the remote worker, so it is not required to be locally
// available on the caller side. Using RemoteTensorHandle inputs is not
// supported in TensorFlow v1 runtime.
virtual void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
gtl::ArraySlice<FunctionArg> args,
std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) = 0;
// Clean up a previously instantiated function on remote worker.