Merge EagerPFLR and PFLR.

PiperOrigin-RevId: 305145252
Change-Id: Ie6f14c5a1af5335fefd21fbe107dc2c946dd7d66
Yujing Zhang 2020-04-06 16:40:21 -07:00 committed by TensorFlower Gardener
parent cd8b7600af
commit 8424ef8160
11 changed files with 74 additions and 204 deletions
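This change folds the remote-input handling of EagerProcessFunctionLibraryRuntime::Run into ProcessFunctionLibraryRuntime::Run and deletes the eager subclass. As a rough orientation before the per-file diffs, here is a minimal standalone C++ sketch of the resulting dispatch; Status, Args, and DoneCallback below are hypothetical stand-ins, not the real TensorFlow types. Local-only inputs keep the existing path, remote inputs take the multi-device path, and mobile builds report an Unimplemented-style error.

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for tensorflow::Status, Tensor handles, and
// FunctionArgsInterface; only the dispatch shape mirrors the real code.
struct Status {
  bool ok = true;
  std::string message;
};

struct Args {
  std::vector<int> local_tensors;   // placeholder for local Tensors
  std::vector<int> remote_handles;  // placeholder for RemoteTensorHandles
  bool HasRemoteInputs() const { return !remote_handles.empty(); }
};

using DoneCallback = std::function<void(const Status&)>;

void RunLocal(const Args& args, DoneCallback done) {
  // Previous behavior, unchanged: plain local-tensor execution.
  std::cout << "local path, " << args.local_tensors.size() << " args\n";
  done(Status{});
}

void RunMultiDevice(const Args& args, DoneCallback done) {
  // Path that previously lived in EagerProcessFunctionLibraryRuntime::Run.
  std::cout << "multi-device path, " << args.remote_handles.size()
            << " remote args\n";
  done(Status{});
}

void Run(const Args& args, DoneCallback done) {
  if (!args.HasRemoteInputs()) {
    return RunLocal(args, std::move(done));
  }
#if defined(IS_MOBILE_PLATFORM)
  done(Status{false, "Remote inputs are not available on mobile devices."});
#else
  RunMultiDevice(args, std::move(done));
#endif
}

int main() {
  Run({{1, 2}, {}}, [](const Status&) {});  // local-only inputs
  Run({{1}, {42}}, [](const Status&) {});   // has a remote input
  return 0;
}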


@@ -72,7 +72,6 @@ tf_cuda_library(
deps = [
":eager_executor",
":kernel_and_device",
":process_function_library_runtime",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:context_interface",
"//tensorflow/c/eager:tensor_handle_interface",
@@ -290,7 +289,6 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":attr_builder",
":process_function_library_runtime",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
@@ -306,31 +304,6 @@ tf_cuda_library(
}),
)
tf_cuda_library(
name = "process_function_library_runtime",
srcs = [
"process_function_library_runtime.cc",
],
hdrs = [
"process_function_library_runtime.h",
],
visibility = ["//tensorflow:internal"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
}),
)
tf_cc_test(
name = "kernel_and_device_test",
srcs = ["kernel_and_device_test.cc"],
@@ -380,7 +353,6 @@ cc_library(
":eager_operation",
":kernel_and_device",
":tensor_handle",
":process_function_library_runtime",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/strings",


@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/colocation_graph.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/function.h"
@@ -178,17 +177,10 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
*r = CreateRendezvous(step_id);
return Status::OK();
}};
if (lazy_copy_function_remote_inputs_) {
pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime(
device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
thread_pool, cluster_flr, custom_kernel_creator,
/*session_metadata=*/nullptr, std::move(rendezvous_factory)));
} else {
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
thread_pool, cluster_flr, custom_kernel_creator,
/*session_metadata=*/nullptr, std::move(rendezvous_factory)));
}
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
thread_pool, cluster_flr, custom_kernel_creator,
/*session_metadata=*/nullptr, std::move(rendezvous_factory)));
}
void EagerContext::InitPrioritizedDeviceTypeList() {


@@ -1,73 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
#include <iterator>
#include <memory>
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/reffed_status_callback.h"
namespace tensorflow {
namespace eager {
#if !defined(IS_MOBILE_PLATFORM)
void EagerProcessFunctionLibraryRuntime::RunRemoteDevice(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle local_handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
parent_->Run(opts, local_handle, args, rets, std::move(done));
}
void EagerProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
if (!args.HasRemoteInputs()) {
return ProcessFunctionLibraryRuntime::Run(opts, handle, args, rets,
std::move(done));
}
auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
/*rendezvous=*/nullptr);
auto get_component_args = [&args](const ComponentFunctionData& comp_data,
InternalArgs* comp_args) -> Status {
for (int i = 0; i < comp_data.arg_indices_.size(); ++i) {
const int index = comp_data.arg_indices_.at(i);
Tensor tensor;
if (args.GetLocalArg(index, &tensor).ok()) {
comp_args->args.push_back(std::move(tensor));
} else {
RemoteTensorHandle remote_handle;
TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
comp_args->remote_args.emplace_back(
absl::make_unique<RemoteTensorHandle>(std::move(remote_handle)));
comp_args->args.push_back(comp_args->remote_args.back().get());
}
}
return Status::OK();
};
return RunMultiDevice(opts, handle, rets, cleanup_items, std::move(done),
std::move(get_component_args));
}
#endif // IS_MOBILE_PLATFORM
} // namespace eager
} // namespace tensorflow


@@ -1,62 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
#include <memory>
#include <unordered_map>
// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif // IS_MOBILE_PLATFORM
namespace tensorflow {
namespace eager {
// A ProcessFunctionLibraryRuntime which supports running functions with inputs
// on remote devices.
// TODO(b/134094971): Support outputting tensors on remote devices.
class EagerProcessFunctionLibraryRuntime
: public ProcessFunctionLibraryRuntime {
public:
using ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime;
#if !defined(IS_MOBILE_PLATFORM)
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
const FunctionArgsInterface& args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const override;
private:
void RunRemoteDevice(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle local_handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const override;
#endif // IS_MOBILE_PLATFORM
};
} // namespace eager
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_


@@ -49,6 +49,9 @@ limitations under the License.
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif // IS_MOBILE_PLATFORM
namespace tensorflow {
@@ -973,14 +976,6 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
return Status::OK();
}
void ProcessFunctionLibraryRuntime::RunRemoteDevice(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle local_handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
parent_->Run(opts, local_handle, GetLocalArgs(args), rets, std::move(done));
}
void ProcessFunctionLibraryRuntime::RunMultiDevice(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, std::vector<Tensor>* rets,
@@ -1426,7 +1421,7 @@ void ProcessFunctionLibraryRuntime::RunInternal(
cleanup_item->step_id = opts.step_id;
cleanup_item->local_handle = local_handle;
cleanup_items->emplace_back(std::move(cleanup_item));
RunRemoteDevice(opts, local_handle, args, rets, std::move(done));
parent_->Run(opts, local_handle, args, rets, std::move(done));
return;
}
done(errors::Internal("Could not find device"));
@@ -1483,8 +1478,41 @@ void ProcessFunctionLibraryRuntime::Run(
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
const std::vector<Tensor> local_inputs = args.GetLocalTensors();
Run(opts, handle, local_inputs, rets, std::move(done));
if (!args.HasRemoteInputs()) {
const std::vector<Tensor> local_inputs = args.GetLocalTensors();
return Run(opts, handle, local_inputs, rets, std::move(done));
}
#if defined(IS_MOBILE_PLATFORM)
done(errors::Unimplemented(
"Remote inputs are not available on mobile devices."));
return;
#else // !IS_MOBILE_PLATFORM
auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
/*rendezvous=*/nullptr);
auto get_component_args = [&args](const ComponentFunctionData& comp_data,
InternalArgs* comp_args) -> Status {
for (int i = 0; i < comp_data.arg_indices_.size(); ++i) {
const int index = comp_data.arg_indices_.at(i);
Tensor tensor;
if (args.GetLocalArg(index, &tensor).ok()) {
comp_args->args.push_back(std::move(tensor));
} else {
eager::RemoteTensorHandle remote_handle;
TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
comp_args->remote_args.emplace_back(
absl::make_unique<eager::RemoteTensorHandle>(
std::move(remote_handle)));
comp_args->args.push_back(comp_args->remote_args.back().get());
}
}
return Status::OK();
};
return RunMultiDevice(opts, handle, rets, cleanup_items, std::move(done),
std::move(get_component_args));
#endif // !IS_MOBILE_PLATFORM
}
void ProcessFunctionLibraryRuntime::CleanUp(


@@ -73,7 +73,7 @@ class ProcessFunctionLibraryRuntime {
const SessionMetadata* session_metadata = nullptr,
Rendezvous::Factory rendezvous_factory = Rendezvous::Factory());
virtual ~ProcessFunctionLibraryRuntime() {
~ProcessFunctionLibraryRuntime() {
// Deleting the FunctionLibraryRuntime map will delete the function handles
// registered in it, which may call ReleaseHandle in this class again to
// release their sub-function. These circular calls may cause segfault
@@ -184,10 +184,10 @@ class ProcessFunctionLibraryRuntime {
FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
FunctionLibraryRuntime::DoneCallback done) const;
virtual void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
const FunctionArgsInterface& args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const;
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
const FunctionArgsInterface& args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const;
const DeviceMgr* device_mgr() { return device_mgr_; }
@@ -275,12 +275,6 @@ class ProcessFunctionLibraryRuntime {
FunctionLibraryRuntime::Handle local_handle;
};
virtual void RunRemoteDevice(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle local_handle,
gtl::ArraySlice<FunctionArg> args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const;
// If `handle` represents a multi-device function, returns the multi-device
// data associated with `handle`. Else, nullptr.
MultiDeviceFunctionData* IsMultiDevice(


@@ -330,6 +330,25 @@ void ClusterFunctionLibraryRuntime::Run(
});
}
void ClusterFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) {
std::vector<Tensor> tensors;
for (const auto& arg : args) {
if (arg.index() == 0) {
tensors.push_back(absl::get<Tensor>(arg));
} else {
done(
errors::Internal("ClusterFunctionLibraryRuntime doesn't support "
"eager::RemoteTensorHandle."));
return;
}
}
return Run(opts, handle, tensors, rets, std::move(done));
}
void ClusterFunctionLibraryRuntime::CleanUp(
uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
FunctionLibraryRuntime::DoneCallback done) {


@@ -47,6 +47,11 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) override;
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) override;
void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
FunctionLibraryRuntime::DoneCallback done) override;
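The new ClusterFunctionLibraryRuntime::Run overload added above unpacks variant-typed FunctionArgs into plain Tensors and rejects remote handles with an Internal error. A minimal standalone sketch of that conversion follows, using std::variant and hypothetical Tensor/RemoteTensorHandle stand-ins in place of the absl::variant-backed FunctionArg.

#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for tensorflow::Tensor and eager::RemoteTensorHandle.
struct Tensor { int value = 0; };
struct RemoteTensorHandle { int op_id = 0; };

// FunctionArg is an absl::variant of these two in TensorFlow; std::variant
// behaves the same way for this sketch.
using FunctionArg = std::variant<Tensor, RemoteTensorHandle>;

// Forward Tensor-valued args; report an error as soon as a remote handle
// is seen, mirroring the Internal error in the new overload.
bool ExtractLocalTensors(const std::vector<FunctionArg>& args,
                         std::vector<Tensor>* tensors, std::string* error) {
  for (const auto& arg : args) {
    if (arg.index() == 0) {
      tensors->push_back(std::get<Tensor>(arg));
    } else {
      *error =
          "ClusterFunctionLibraryRuntime doesn't support RemoteTensorHandle.";
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<Tensor> tensors;
  std::string error;
  if (ExtractLocalTensors({Tensor{1}, Tensor{2}}, &tensors, &error)) {
    std::cout << "forwarded " << tensors.size() << " local tensors\n";
  }
  if (!ExtractLocalTensors({Tensor{3}, RemoteTensorHandle{7}}, &tensors,
                           &error)) {
    std::cout << "error: " << error << "\n";
  }
  return 0;
}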


@@ -112,7 +112,6 @@ cc_library(
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:process_function_library_runtime",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:message_wrappers",
"//tensorflow/core/distributed_runtime:server_lib",
@@ -146,7 +145,6 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:process_function_library_runtime",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:test_utils",


@@ -23,7 +23,6 @@ limitations under the License.
#include "absl/types/variant.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
@@ -554,7 +553,7 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
fdef_ = MatMulFunction();
TF_ASSERT_OK(func_lib_def_.AddFunctionDef(fdef_));
eager_pflr_ = absl::make_unique<EagerProcessFunctionLibraryRuntime>(
eager_pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
remote_device_mgr_.get(), Env::Default(), /*config=*/
nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
/*thread_pool=*/nullptr, eager_cluster_flr_.get());
@@ -598,7 +597,7 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
};
// Test executes a remote function through
// EagerProcessFunctionLibraryRuntime(EagerClusterFunctionLibraryRuntime).
// ProcessFunctionLibraryRuntime(EagerClusterFunctionLibraryRuntime).
TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
Init();
// Instantiate MatMulFunction on remote_device.


@@ -867,7 +867,6 @@ class DistributedFunctionLibraryRuntime {
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) = 0;
#if !defined(IS_MOBILE_PLATFORM)
// TODO(yujingzhang): Support outputting tensors on remote devices.
virtual void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
@@ -875,7 +874,6 @@ class DistributedFunctionLibraryRuntime {
FunctionLibraryRuntime::DoneCallback done) {
done(errors::Unimplemented("Unimplemented."));
}
#endif // IS_MOBILE_PLATFORM
virtual void CleanUp(uint64 step_id,
FunctionLibraryRuntime::LocalHandle handle,