Merge EagerPFLR and PFLR.

PiperOrigin-RevId: 305145252
Change-Id: Ie6f14c5a1af5335fefd21fbe107dc2c946dd7d66

parent cd8b7600af
commit 8424ef8160
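Summary for readers of the hunks below: eager::EagerProcessFunctionLibraryRuntime is folded into ProcessFunctionLibraryRuntime, its source files are deleted, and callers drop their dependency on the eager process_function_library_runtime target. A minimal sketch of what construction looks like after the merge, mirroring the EagerContext::ResetPFLR hunk further down (every identifier is taken from that hunk; nothing here is new API):

    // Sketch only: a single runtime class is built regardless of whether remote
    // inputs may appear; Run() now checks args.HasRemoteInputs() itself and
    // falls back to the local path when every input is a plain Tensor.
    pflr_.reset(new ProcessFunctionLibraryRuntime(
        device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
        thread_pool, cluster_flr, custom_kernel_creator,
        /*session_metadata=*/nullptr, std::move(rendezvous_factory)));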
@@ -72,7 +72,6 @@ tf_cuda_library(
     deps = [
         ":eager_executor",
         ":kernel_and_device",
-        ":process_function_library_runtime",
         "//tensorflow/c:tf_tensor_internal",
         "//tensorflow/c/eager:context_interface",
         "//tensorflow/c/eager:tensor_handle_interface",
@@ -290,7 +289,6 @@ tf_cuda_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         ":attr_builder",
-        ":process_function_library_runtime",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:optional",
@@ -306,31 +304,6 @@ tf_cuda_library(
     }),
 )
 
-tf_cuda_library(
-    name = "process_function_library_runtime",
-    srcs = [
-        "process_function_library_runtime.cc",
-    ],
-    hdrs = [
-        "process_function_library_runtime.h",
-    ],
-    visibility = ["//tensorflow:internal"],
-    deps = select({
-        "//tensorflow:android": [
-            "//tensorflow/core:android_tensorflow_lib_lite",
-        ],
-        "//conditions:default": [
-            "@com_google_absl//absl/types:optional",
-            "@com_google_absl//absl/types:variant",
-            "//tensorflow/core:core_cpu_lib",
-            "//tensorflow/core:framework",
-            "//tensorflow/core:lib",
-            "//tensorflow/core:lib_internal",
-            "//tensorflow/core:protos_all_cc",
-        ],
-    }),
-)
-
 tf_cc_test(
     name = "kernel_and_device_test",
     srcs = ["kernel_and_device_test.cc"],
@@ -380,7 +353,6 @@ cc_library(
         ":eager_operation",
         ":kernel_and_device",
         ":tensor_handle",
-        ":process_function_library_runtime",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/strings",
@@ -36,7 +36,6 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/colocation_graph.h"
 #include "tensorflow/core/common_runtime/device_resolver_local.h"
 #include "tensorflow/core/common_runtime/device_set.h"
-#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/framework/graph_def_util.h"
 #include "tensorflow/core/framework/function.h"
@@ -178,17 +177,10 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
         *r = CreateRendezvous(step_id);
         return Status::OK();
       }};
-  if (lazy_copy_function_remote_inputs_) {
-    pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime(
-        device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
-        thread_pool, cluster_flr, custom_kernel_creator,
-        /*session_metadata=*/nullptr, std::move(rendezvous_factory)));
-  } else {
-    pflr_.reset(new ProcessFunctionLibraryRuntime(
-        device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
-        thread_pool, cluster_flr, custom_kernel_creator,
-        /*session_metadata=*/nullptr, std::move(rendezvous_factory)));
-  }
+  pflr_.reset(new ProcessFunctionLibraryRuntime(
+      device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
+      thread_pool, cluster_flr, custom_kernel_creator,
+      /*session_metadata=*/nullptr, std::move(rendezvous_factory)));
 }
 
 void EagerContext::InitPrioritizedDeviceTypeList() {
@@ -1,73 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
-
-#include <iterator>
-#include <memory>
-#include <utility>
-
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/util/reffed_status_callback.h"
-
-namespace tensorflow {
-namespace eager {
-
-#if !defined(IS_MOBILE_PLATFORM)
-void EagerProcessFunctionLibraryRuntime::RunRemoteDevice(
-    const FunctionLibraryRuntime::Options& opts,
-    FunctionLibraryRuntime::Handle local_handle,
-    gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
-    FunctionLibraryRuntime::DoneCallback done) const {
-  parent_->Run(opts, local_handle, args, rets, std::move(done));
-}
-
-void EagerProcessFunctionLibraryRuntime::Run(
-    const FunctionLibraryRuntime::Options& opts,
-    FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
-    std::vector<Tensor>* rets,
-    FunctionLibraryRuntime::DoneCallback done) const {
-  if (!args.HasRemoteInputs()) {
-    return ProcessFunctionLibraryRuntime::Run(opts, handle, args, rets,
-                                              std::move(done));
-  }
-  auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
-  done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
-                                    /*rendezvous=*/nullptr);
-
-  auto get_component_args = [&args](const ComponentFunctionData& comp_data,
-                                    InternalArgs* comp_args) -> Status {
-    for (int i = 0; i < comp_data.arg_indices_.size(); ++i) {
-      const int index = comp_data.arg_indices_.at(i);
-      Tensor tensor;
-      if (args.GetLocalArg(index, &tensor).ok()) {
-        comp_args->args.push_back(std::move(tensor));
-      } else {
-        RemoteTensorHandle remote_handle;
-        TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
-        comp_args->remote_args.emplace_back(
-            absl::make_unique<RemoteTensorHandle>(std::move(remote_handle)));
-        comp_args->args.push_back(comp_args->remote_args.back().get());
-      }
-    }
-    return Status::OK();
-  };
-  return RunMultiDevice(opts, handle, rets, cleanup_items, std::move(done),
-                        std::move(get_component_args));
-}
-#endif  // IS_MOBILE_PLATFORM
-
-}  // namespace eager
-}  // namespace tensorflow
@@ -1,62 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
-
-#include <memory>
-#include <unordered_map>
-
-// clang-format off
-// Required for IS_MOBILE_PLATFORM
-#include "tensorflow/core/platform/platform.h"
-// clang-format on
-
-#include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#if !defined(IS_MOBILE_PLATFORM)
-#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
-#endif  // IS_MOBILE_PLATFORM
-
-namespace tensorflow {
-namespace eager {
-
-// A ProcessFunctionLibraryRuntime which supports running functions with inputs
-// on remote devices.
-// TODO(b/134094971): Support outputting tensors on remote devices.
-class EagerProcessFunctionLibraryRuntime
-    : public ProcessFunctionLibraryRuntime {
- public:
-  using ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime;
-
-#if !defined(IS_MOBILE_PLATFORM)
-  void Run(const FunctionLibraryRuntime::Options& opts,
-           FunctionLibraryRuntime::Handle handle,
-           const FunctionArgsInterface& args, std::vector<Tensor>* rets,
-           FunctionLibraryRuntime::DoneCallback done) const override;
-
- private:
-  void RunRemoteDevice(
-      const FunctionLibraryRuntime::Options& opts,
-      FunctionLibraryRuntime::Handle local_handle,
-      gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
-      FunctionLibraryRuntime::DoneCallback done) const override;
-#endif  // IS_MOBILE_PLATFORM
-};
-
-}  // namespace eager
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
@@ -49,6 +49,9 @@ limitations under the License.
 #include "tensorflow/core/util/dump_graph.h"
 #include "tensorflow/core/util/ptr_util.h"
 #include "tensorflow/core/util/reffed_status_callback.h"
+#if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
+#endif  // IS_MOBILE_PLATFORM
 
 namespace tensorflow {
 
@@ -973,14 +976,6 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
   return Status::OK();
 }
 
-void ProcessFunctionLibraryRuntime::RunRemoteDevice(
-    const FunctionLibraryRuntime::Options& opts,
-    FunctionLibraryRuntime::Handle local_handle,
-    gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
-    FunctionLibraryRuntime::DoneCallback done) const {
-  parent_->Run(opts, local_handle, GetLocalArgs(args), rets, std::move(done));
-}
-
 void ProcessFunctionLibraryRuntime::RunMultiDevice(
     const FunctionLibraryRuntime::Options& opts,
     FunctionLibraryRuntime::Handle handle, std::vector<Tensor>* rets,
@@ -1426,7 +1421,7 @@ void ProcessFunctionLibraryRuntime::RunInternal(
     cleanup_item->step_id = opts.step_id;
     cleanup_item->local_handle = local_handle;
     cleanup_items->emplace_back(std::move(cleanup_item));
-    RunRemoteDevice(opts, local_handle, args, rets, std::move(done));
+    parent_->Run(opts, local_handle, args, rets, std::move(done));
     return;
   }
   done(errors::Internal("Could not find device"));
@@ -1483,8 +1478,41 @@ void ProcessFunctionLibraryRuntime::Run(
     FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
     std::vector<Tensor>* rets,
     FunctionLibraryRuntime::DoneCallback done) const {
-  const std::vector<Tensor> local_inputs = args.GetLocalTensors();
-  Run(opts, handle, local_inputs, rets, std::move(done));
+  if (!args.HasRemoteInputs()) {
+    const std::vector<Tensor> local_inputs = args.GetLocalTensors();
+    return Run(opts, handle, local_inputs, rets, std::move(done));
+  }
+
+#if defined(IS_MOBILE_PLATFORM)
+  done(errors::Unimplemented(
+      "Remote inputs are not available on mobile devices."));
+  return;
+#else  // !IS_MOBILE_PLATFORM
+  auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
+  done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
+                                    /*rendezvous=*/nullptr);
+
+  auto get_component_args = [&args](const ComponentFunctionData& comp_data,
+                                    InternalArgs* comp_args) -> Status {
+    for (int i = 0; i < comp_data.arg_indices_.size(); ++i) {
+      const int index = comp_data.arg_indices_.at(i);
+      Tensor tensor;
+      if (args.GetLocalArg(index, &tensor).ok()) {
+        comp_args->args.push_back(std::move(tensor));
+      } else {
+        eager::RemoteTensorHandle remote_handle;
+        TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
+        comp_args->remote_args.emplace_back(
+            absl::make_unique<eager::RemoteTensorHandle>(
+                std::move(remote_handle)));
+        comp_args->args.push_back(comp_args->remote_args.back().get());
+      }
+    }
+    return Status::OK();
+  };
+  return RunMultiDevice(opts, handle, rets, cleanup_items, std::move(done),
+                        std::move(get_component_args));
+#endif  // !IS_MOBILE_PLATFORM
 }
 
 void ProcessFunctionLibraryRuntime::CleanUp(
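The hunk above moves the remote-input handling that used to live in EagerProcessFunctionLibraryRuntime::Run into ProcessFunctionLibraryRuntime::Run itself: when any input is remote, each component function's arguments are gathered by trying GetLocalArg first and falling back to GetRemoteArg, with the remote handles kept alive until the done callback runs. Below is a self-contained sketch of that lookup-with-fallback pattern; LocalTensor, RemoteHandle, and ArgSource are illustrative placeholders, not TensorFlow classes.

    // Illustrative sketch of the "try local, fall back to remote" argument
    // gathering used by the merged Run(); all types here are placeholders.
    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>
    #include <vector>

    struct LocalTensor { float value; };
    struct RemoteHandle { std::string device; };

    struct ArgSource {
      std::map<int, LocalTensor> local;
      std::map<int, RemoteHandle> remote;
      bool GetLocalArg(int i, LocalTensor* out) const {
        auto it = local.find(i);
        if (it == local.end()) return false;
        *out = it->second;
        return true;
      }
      bool GetRemoteArg(int i, RemoteHandle* out) const {
        auto it = remote.find(i);
        if (it == remote.end()) return false;
        *out = it->second;
        return true;
      }
    };

    int main() {
      ArgSource args;
      args.local[0] = {1.0f};
      args.remote[1] = {"/job:worker/replica:0/task:1"};

      std::vector<LocalTensor> local_args;
      // Remote handles are owned here so they stay alive while the component
      // function runs, in the spirit of comp_args->remote_args above.
      std::vector<std::unique_ptr<RemoteHandle>> remote_args;
      for (int index : {0, 1}) {
        LocalTensor tensor;
        if (args.GetLocalArg(index, &tensor)) {
          local_args.push_back(tensor);  // local input: pass the tensor itself
        } else {
          RemoteHandle handle;
          if (!args.GetRemoteArg(index, &handle)) return 1;
          remote_args.push_back(
              std::make_unique<RemoteHandle>(handle));  // remote input: forward a handle
        }
      }
      std::cout << local_args.size() << " local, " << remote_args.size()
                << " remote argument(s)\n";
      return 0;
    }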
@@ -73,7 +73,7 @@ class ProcessFunctionLibraryRuntime {
       const SessionMetadata* session_metadata = nullptr,
       Rendezvous::Factory rendezvous_factory = Rendezvous::Factory());
 
-  virtual ~ProcessFunctionLibraryRuntime() {
+  ~ProcessFunctionLibraryRuntime() {
     // Deleting the FunctionLibraryRuntime map will delete the function handles
     // registered in it, which may call ReleaseHandle in this class again to
    // release their sub-function. These circular calls may casue segfault
@@ -184,10 +184,10 @@ class ProcessFunctionLibraryRuntime {
       FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
       FunctionLibraryRuntime::DoneCallback done) const;
 
-  virtual void Run(const FunctionLibraryRuntime::Options& opts,
+  void Run(const FunctionLibraryRuntime::Options& opts,
            FunctionLibraryRuntime::Handle handle,
            const FunctionArgsInterface& args, std::vector<Tensor>* rets,
            FunctionLibraryRuntime::DoneCallback done) const;
 
   const DeviceMgr* device_mgr() { return device_mgr_; }
 
@@ -275,12 +275,6 @@ class ProcessFunctionLibraryRuntime {
     FunctionLibraryRuntime::Handle local_handle;
   };
 
-  virtual void RunRemoteDevice(const FunctionLibraryRuntime::Options& opts,
-                               FunctionLibraryRuntime::Handle local_handle,
-                               gtl::ArraySlice<FunctionArg> args,
-                               std::vector<Tensor>* rets,
-                               FunctionLibraryRuntime::DoneCallback done) const;
-
   // If `handle` represents a multi-device function, returns the multi-device
   // data associated with `handle`. Else, nullptr.
   MultiDeviceFunctionData* IsMultiDevice(
@@ -330,6 +330,25 @@ void ClusterFunctionLibraryRuntime::Run(
       });
 }
 
+void ClusterFunctionLibraryRuntime::Run(
+    const FunctionLibraryRuntime::Options& opts,
+    FunctionLibraryRuntime::LocalHandle handle,
+    gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
+    FunctionLibraryRuntime::DoneCallback done) {
+  std::vector<Tensor> tensors;
+  for (const auto& arg : args) {
+    if (arg.index() == 0) {
+      tensors.push_back(absl::get<Tensor>(arg));
+    } else {
+      done(
+          errors::Internal("ClusterFunctionLibraryRuntime doesn't support "
+                           "eager::RemoteTensorHandle."));
+      return;
+    }
+  }
+  return Run(opts, handle, tensors, rets, std::move(done));
+}
+
 void ClusterFunctionLibraryRuntime::CleanUp(
     uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
     FunctionLibraryRuntime::DoneCallback done) {
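The new overload above accepts FunctionArg values but only forwards plain tensors, reporting errors::Internal when it meets a remote handle; the arg.index() == 0 check suggests FunctionArg is a variant whose first alternative is Tensor. A tiny self-contained sketch of that dispatch follows, using std::variant and placeholder types rather than TensorFlow's own:

    // Illustrative only: variant dispatch in the spirit of the new overload.
    #include <iostream>
    #include <string>
    #include <variant>
    #include <vector>

    struct LocalTensor { float value; };
    struct RemoteHandle { std::string device; };
    using Arg = std::variant<LocalTensor, RemoteHandle>;  // index 0 holds the tensor

    int main() {
      std::vector<Arg> args = {LocalTensor{2.0f},
                               RemoteHandle{"/job:worker/task:1"}};
      std::vector<LocalTensor> tensors;
      for (const auto& arg : args) {
        if (arg.index() == 0) {
          tensors.push_back(std::get<LocalTensor>(arg));  // plain tensor: keep it
        } else {
          std::cout << "error: remote handles are not supported here\n";
          return 1;  // mirrors the errors::Internal path above
        }
      }
      std::cout << "forwarding " << tensors.size() << " tensor(s)\n";
      return 0;
    }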
@@ -47,6 +47,11 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
            gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
            FunctionLibraryRuntime::DoneCallback done) override;
 
+  void Run(const FunctionLibraryRuntime::Options& opts,
+           FunctionLibraryRuntime::LocalHandle handle,
+           gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
+           FunctionLibraryRuntime::DoneCallback done) override;
+
   void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
                FunctionLibraryRuntime::DoneCallback done) override;
 
@@ -112,7 +112,6 @@ cc_library(
         "//tensorflow/core/common_runtime/eager:core",
         "//tensorflow/core/common_runtime/eager:eager_operation",
         "//tensorflow/core/common_runtime/eager:execute",
-        "//tensorflow/core/common_runtime/eager:process_function_library_runtime",
         "//tensorflow/core/common_runtime/eager:tensor_handle",
         "//tensorflow/core/distributed_runtime:message_wrappers",
         "//tensorflow/core/distributed_runtime:server_lib",
@@ -146,7 +145,6 @@ tf_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/common_runtime/eager:kernel_and_device",
-        "//tensorflow/core/common_runtime/eager:process_function_library_runtime",
         "//tensorflow/core/common_runtime/eager:tensor_handle",
         "//tensorflow/core/distributed_runtime:session_mgr",
         "//tensorflow/core/distributed_runtime:test_utils",
@@ -23,7 +23,6 @@ limitations under the License.
 #include "absl/types/variant.h"
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
-#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
@@ -554,7 +553,7 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
 
     fdef_ = MatMulFunction();
     TF_ASSERT_OK(func_lib_def_.AddFunctionDef(fdef_));
-    eager_pflr_ = absl::make_unique<EagerProcessFunctionLibraryRuntime>(
+    eager_pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
         remote_device_mgr_.get(), Env::Default(), /*config=*/
         nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
         /*thread_pool=*/nullptr, eager_cluster_flr_.get());
@@ -598,7 +597,7 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
 };
 
 // Test executes a remote function through
-// EagerProcessFunctionLibraryRuntime(EagerClusterFunctionLibraryRuntime).
+// ProcessFunctionLibraryRuntime(EagerClusterFunctionLibraryRuntime).
 TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
   Init();
   // Instantiate MatMulFunction on remote_device.
@@ -867,7 +867,6 @@ class DistributedFunctionLibraryRuntime {
                    gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                    FunctionLibraryRuntime::DoneCallback done) = 0;
 
-#if !defined(IS_MOBILE_PLATFORM)
   // TODO(yujingzhang): Support outputting tensors on remote devices.
   virtual void Run(const FunctionLibraryRuntime::Options& opts,
                    FunctionLibraryRuntime::LocalHandle handle,
@@ -875,7 +874,6 @@ class DistributedFunctionLibraryRuntime {
                        FunctionLibraryRuntime::DoneCallback done) {
     done(errors::Unimplemented("Unimplemented."));
   }
-#endif  // IS_MOBILE_PLATFORM
 
   virtual void CleanUp(uint64 step_id,
                        FunctionLibraryRuntime::LocalHandle handle,