Move distributed initializing and updating logic from C API to eager ContextDistributedManager.
PiperOrigin-RevId: 342981979 Change-Id: I56f0d7d71ca72d4194a6ec4ca8d228f9733f9bde
Parent: 1d57024588
Commit: 9da79c731c
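At a glance: the cluster setup, update, and liveness-check logic that previously lived inline in the eager C API now sits behind the new ImmediateExecutionDistributedManager interface, and the C API entry points simply forward to the context's manager. A condensed view of the new call path, simplified from the hunks below (error handling and the mobile-platform guard omitted):

    // TFE_ContextSetServerDef / TFE_ContextUpdateServerDef now delegate the
    // cluster bookkeeping to the context's distributed manager.
    status->status =
        tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
            server_def, /*reset_context=*/true, keep_alive_secs);

    // TFE_ContextCheckAlive asks the manager instead of issuing the GetStatus
    // RPC itself.
    bool is_alive;
    status->status =
        tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
            worker_name, &is_alive);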
tensorflow/c/eager/BUILD
@@ -51,6 +51,7 @@ tf_cuda_library(
        ":immediate_execution_context",
        ":immediate_execution_operation",
        ":immediate_execution_tensor_handle",
        ":immediate_execution_distributed_manager",
        ":abstract_tensor_handle",
        ":tfe_context_internal",
        ":tfe_cancellation_manager_internal",
@@ -70,6 +71,7 @@ tf_cuda_library(
        "//tensorflow/core:core_cpu",
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/common_runtime/eager:context_distributed_manager",
        "//tensorflow/core/common_runtime/eager:core",
        "//tensorflow/core/common_runtime/eager:eager_executor",
        "//tensorflow/core/common_runtime/eager:execute",
@@ -119,6 +121,7 @@ filegroup(
        "gradients.h",
        "gradients_internal.h",
        "immediate_execution_context.h",
        "immediate_execution_distributed_manager.h",
        "immediate_execution_operation.h",
        "immediate_execution_tensor_handle.h",
        "tape.h",
@@ -584,6 +587,19 @@ cc_library(
    ],
)

cc_library(
    name = "immediate_execution_distributed_manager",
    hdrs = ["immediate_execution_distributed_manager.h"],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "immediate_execution_context",
    hdrs = ["immediate_execution_context.h"],
@@ -592,12 +608,14 @@ cc_library(
    ],
    deps = [
        ":abstract_context",
        ":immediate_execution_distributed_manager",
        ":immediate_execution_operation",
        ":immediate_execution_tensor_handle",
        "//tensorflow/c:tensor_interface",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],

tensorflow/c/eager/c_api.cc
@@ -21,16 +21,11 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
|
||||
// clang-format off
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
@ -39,59 +34,39 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/device_filters.pb.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/platform/random.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
// "tensorflow/core/platform/platform.h" must be included first before using
|
||||
// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace {
|
||||
@ -100,611 +75,6 @@ string DeviceName(const tensorflow::Device* d) {
|
||||
return (d == nullptr) ? "cpu:0" : d->name();
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
|
||||
const tensorflow::ServerDef& server_def) {
|
||||
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
|
||||
return false;
|
||||
}
|
||||
return server_def.default_session_config().SerializeAsString() ==
|
||||
context->session_options().config.SerializeAsString();
|
||||
}
|
||||
|
||||
tensorflow::Status AddRemoteDevicesToMgr(
|
||||
const std::vector<string>& added_remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
|
||||
tensorflow::mutex remote_devices_mu;
|
||||
int num_added_workers = added_remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_added_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_added_workers);
|
||||
for (int i = 0; i < num_added_workers; i++) {
|
||||
tensorflow::NewRemoteDevices(
|
||||
tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
|
||||
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
|
||||
const tensorflow::Status& s,
|
||||
std::vector<tensorflow::Device*>* devices) {
|
||||
statuses[i] = s;
|
||||
if (s.ok()) {
|
||||
tensorflow::mutex_lock l(remote_devices_mu);
|
||||
for (tensorflow::Device* d : *devices) {
|
||||
remote_devices.emplace_back(d);
|
||||
}
|
||||
}
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < num_added_workers; i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status GetAllRemoteDevices(
|
||||
const std::vector<string>& remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
|
||||
auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
|
||||
TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
|
||||
remote_device_mgr.get()));
|
||||
*device_mgr = std::move(remote_device_mgr);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status RemoveRemoteDevicesFromMgr(
|
||||
const std::vector<string>& removed_remote_workers,
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
|
||||
const std::vector<tensorflow::Device*> remote_devices =
|
||||
(remote_device_mgr->ListDevices());
|
||||
std::vector<tensorflow::Device*> devices_to_remove;
|
||||
for (tensorflow::Device* d : remote_devices) {
|
||||
for (const string& remote_worker : removed_remote_workers) {
|
||||
if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
|
||||
d->name())) {
|
||||
devices_to_remove.emplace_back(d);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
|
||||
const string& local_worker,
|
||||
std::vector<string>* remote_workers) {
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(server);
|
||||
if (grpc_server == nullptr) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
|
||||
}
|
||||
grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
|
||||
remote_workers->erase(
|
||||
std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
|
||||
remote_workers->end());
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void DifferentiateWorkerLists(const std::vector<string>* current_list,
|
||||
const std::vector<string>* new_list,
|
||||
std::vector<string>* added,
|
||||
std::vector<string>* removed,
|
||||
std::vector<string>* existing) {
|
||||
// Get STL set_difference and set_intersection with one list traversal.
|
||||
// Similar to the set_difference library function, the input lists
|
||||
// (`current_list` and `new_list`) must be sorted before calling the function.
|
||||
added->resize(new_list->size());
|
||||
removed->resize(current_list->size());
|
||||
existing->resize(current_list->size());
|
||||
std::vector<string>::const_iterator curr_it = current_list->begin();
|
||||
std::vector<string>::const_iterator new_it = new_list->begin();
|
||||
std::vector<string>::iterator added_it = added->begin();
|
||||
std::vector<string>::iterator removed_it = removed->begin();
|
||||
std::vector<string>::iterator existing_it = existing->begin();
|
||||
while (curr_it != current_list->end() && new_it != new_list->end()) {
|
||||
if (*curr_it < *new_it) {
|
||||
*removed_it++ = *curr_it++;
|
||||
} else if (*curr_it > *new_it) {
|
||||
*added_it++ = *new_it++;
|
||||
} else {
|
||||
*existing_it++ = *curr_it++;
|
||||
new_it++;
|
||||
}
|
||||
}
|
||||
removed_it = std::copy(curr_it, current_list->end(), removed_it);
|
||||
added_it = std::copy(new_it, new_list->end(), added_it);
|
||||
added->resize(added_it - added->begin());
|
||||
removed->resize(removed_it - removed->begin());
|
||||
existing->resize(existing_it - existing->begin());
|
||||
}
|
||||
|
||||
tensorflow::Status GetReplacedFromExistingWorkers(
|
||||
const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
|
||||
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* client_cache,
|
||||
std::vector<string>* replaced_workers) {
|
||||
tensorflow::BlockingCounter counter(existing_workers->size());
|
||||
std::vector<tensorflow::Status> statuses(existing_workers->size());
|
||||
tensorflow::eager::KeepAliveRequest request;
|
||||
request.set_context_id(context_id);
|
||||
std::vector<tensorflow::eager::KeepAliveResponse> responses(
|
||||
existing_workers->size());
|
||||
for (int i = 0; i < existing_workers->size(); i++) {
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] =
|
||||
client_cache->GetClient(existing_workers->at(i), &eager_client);
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
eager_client->KeepAliveAsync(
|
||||
&request, &responses[i],
|
||||
[i, &statuses, &counter](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < existing_workers->size(); i++) {
|
||||
// If the RPC fails (indicating that the requested ID doesn't exist on
|
||||
// remote), or the returned view ID is not equal to the local one
|
||||
// (indicating that the remote worker has a stale view of cluster), treat
|
||||
// the worker as replaced.
|
||||
if (!statuses[i].ok() ||
|
||||
responses[i].context_view_id() != context_view_id) {
|
||||
replaced_workers->emplace_back(existing_workers->at(i));
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status CreateRemoteContexts(
|
||||
TFE_Context* ctx, const std::vector<string>& remote_workers,
|
||||
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
|
||||
const bool lazy_copy_remote_function_inputs,
|
||||
const tensorflow::eager::CreateContextRequest& base_request) {
|
||||
int num_remote_workers = remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_remote_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_remote_workers);
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
const string& remote_worker = remote_workers[i];
|
||||
tensorflow::DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
|
||||
&parsed_name)) {
|
||||
statuses[i] = tensorflow::errors::InvalidArgument(
|
||||
"Unable to parse ", remote_worker, " as a device name");
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
|
||||
if (eager_client == nullptr) {
|
||||
statuses[i] = tensorflow::errors::Internal(
|
||||
"Cannot find a client for the given target:", remote_worker);
|
||||
}
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::eager::CreateContextRequest request;
|
||||
tensorflow::eager::CreateContextResponse* response =
|
||||
new tensorflow::eager::CreateContextResponse();
|
||||
request.set_context_id(context_id);
|
||||
request.set_context_view_id(context_view_id);
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
|
||||
server_def.default_session_config());
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(),
|
||||
base_request.cluster_device_attributes_size());
|
||||
for (int i = 0; i < filtered_device_mask.size(); i++) {
|
||||
if (filtered_device_mask[i]) {
|
||||
const auto& da = base_request.cluster_device_attributes(i);
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
}
|
||||
request.set_async(async);
|
||||
request.set_keep_alive_secs(keep_alive_secs);
|
||||
request.set_lazy_copy_remote_function_inputs(
|
||||
lazy_copy_remote_function_inputs);
|
||||
|
||||
eager_client->CreateContextAsync(
|
||||
&request, response,
|
||||
[i, &statuses, &counter, response](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
delete response;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
tensorflow::StatusGroup sg;
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
if (TF_PREDICT_FALSE(!statuses[i].ok())) {
|
||||
sg.Update(statuses[i]);
|
||||
}
|
||||
}
|
||||
return sg.as_summary_status();
|
||||
}
|
||||
|
||||
tensorflow::Status UpdateRemoteContexts(
|
||||
TFE_Context* ctx, const std::vector<string>& remote_workers,
|
||||
const std::vector<string>& added_workers,
|
||||
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
|
||||
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers,
|
||||
const tensorflow::eager::CreateContextRequest& base_request) {
|
||||
int num_remote_workers = remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_remote_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_remote_workers);
|
||||
|
||||
int cluster_device_count = base_request.cluster_device_attributes_size();
|
||||
std::unordered_set<string> added_or_removed(added_workers.begin(),
|
||||
added_workers.end());
|
||||
std::copy(removed_workers.begin(), removed_workers.end(),
|
||||
std::inserter(added_or_removed, added_or_removed.end()));
|
||||
// Whether each device is in the updated (added or removed) workers
|
||||
std::vector<bool> device_added_or_removed(cluster_device_count);
|
||||
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
|
||||
const auto& da = base_request.cluster_device_attributes().at(i);
|
||||
tensorflow::DeviceNameUtils::ParsedName pn;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
|
||||
string task_name;
|
||||
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
|
||||
if (added_or_removed.find(task_name) != added_or_removed.end()) {
|
||||
device_added_or_removed[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
const string& remote_worker = remote_workers[i];
|
||||
tensorflow::DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
|
||||
&parsed_name)) {
|
||||
statuses[i] = tensorflow::errors::InvalidArgument(
|
||||
"Unable to parse ", remote_worker, " as a device name");
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
|
||||
if (eager_client == nullptr) {
|
||||
statuses[i] = tensorflow::errors::Internal(
|
||||
"Cannot find a client for the given target:", remote_worker);
|
||||
}
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
|
||||
|
||||
// If any of the devices that match the device filters are in the set of
|
||||
// added or removed workers, we must send a complete UpdateContextRequest.
|
||||
// Otherwise, only send a simple request to increment context view ID.
|
||||
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
|
||||
std::transform(device_added_or_removed.begin(),
|
||||
device_added_or_removed.end(), filtered_device_mask.begin(),
|
||||
added_or_removed_filtered_devices.begin(),
|
||||
std::logical_and<bool>());
|
||||
const bool full_update_request =
|
||||
std::accumulate(added_or_removed_filtered_devices.begin(),
|
||||
added_or_removed_filtered_devices.end(), false,
|
||||
std::logical_or<bool>());
|
||||
|
||||
tensorflow::eager::UpdateContextRequest request;
|
||||
auto* response = new tensorflow::eager::UpdateContextResponse();
|
||||
request.set_context_id(context_id);
|
||||
request.set_context_view_id(context_view_id);
|
||||
if (full_update_request) {
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
|
||||
server_def.default_session_config());
|
||||
for (int i = 0; i < cluster_device_count; i++) {
|
||||
if (filtered_device_mask[i]) {
|
||||
const auto& da = base_request.cluster_device_attributes(i);
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eager_client->UpdateContextAsync(
|
||||
&request, response,
|
||||
[i, &statuses, &counter, response](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
delete response;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
TFE_Context* ctx, bool reset_context) {
|
||||
// We don't use the TF_RETURN_IF_ERROR macro directly here because that would
// destroy the server object (which currently CHECK-fails) and the error would
// be lost. Instead, we log the error and then return, so that the user can
// see the message.
|
||||
#define LOG_AND_RETURN_IF_ERROR(...) \
|
||||
do { \
|
||||
const ::tensorflow::Status _status = (__VA_ARGS__); \
|
||||
if (TF_PREDICT_FALSE(!_status.ok())) { \
|
||||
LOG(ERROR) << _status.error_message(); \
|
||||
return _status; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
string worker_name =
|
||||
tensorflow::strings::StrCat("/job:", server_def.job_name(),
|
||||
"/replica:0/task:", server_def.task_index());
|
||||
|
||||
// List of current remote workers before updating server_def. Unused if
|
||||
// resetting the server_def.
|
||||
std::vector<string> curr_remote_workers;
|
||||
// List of updated remote workers.
|
||||
std::vector<string> remote_workers;
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
const tensorflow::DeviceMgr* device_mgr =
|
||||
AreLocalDevicesCompatible(context, server_def)
|
||||
? context->local_device_mgr()
|
||||
: nullptr;
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
|
||||
server_def, {device_mgr}, &new_server));
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
|
||||
&curr_remote_workers));
|
||||
// No need to check the cast here, since `ListRemoteWorkers` already checks
|
||||
// if the server is a GRPC server or not.
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||
}
|
||||
|
||||
tensorflow::uint64 context_id = context->GetContextId();
|
||||
tensorflow::uint64 context_view_id = context->GetContextViewId();
|
||||
if (reset_context) {
|
||||
context_id = tensorflow::EagerContext::NewContextId();
|
||||
context_view_id = 0;
|
||||
// Make master eager context accessible by local eager service, which might
|
||||
// receive send tensor requests from remote workers.
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
||||
&remote_eager_workers));
|
||||
|
||||
// For cluster update, use a status group to aggregate statuses from
|
||||
// * adding and removing remote devices
|
||||
// * creating remote contexts on newly added workers
|
||||
// * updating remote contexts on existing workers
|
||||
// * updating the master context
|
||||
// Note that we should not return immediately on errors in the middle of these
|
||||
// updates to prevent cluster from having inconsistent context views.
|
||||
//
// Unused if `reset_context` is true.
|
||||
tensorflow::StatusGroup sg;
|
||||
|
||||
// When updating an existing context, populate the following lists with:
|
||||
// * added_workers: set(remote_workers) - set(curr_remote_workers)
|
||||
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
|
||||
// * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
|
||||
// * replaced_workers: workers with the same task names and potentially the
|
||||
// same `hostname:port`s, but replaced by different processes
|
||||
std::vector<string> added_workers;
|
||||
std::vector<string> removed_workers;
|
||||
std::vector<string> existing_workers;
|
||||
std::vector<string> replaced_workers;
|
||||
|
||||
// New remote device manager created for new server_def. Unused if updating
|
||||
// server_def.
|
||||
std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
|
||||
remote_workers, grpc_server->master_env()->worker_cache,
|
||||
&new_remote_device_mgr));
|
||||
remote_device_mgr = new_remote_device_mgr.get();
|
||||
} else {
|
||||
context->ClearCachesAndDefaultExecutor();
|
||||
// TODO(b/143914772): Potential memory leak if rendezvous has pending
|
||||
// tensors for removed / replaced workers.
|
||||
|
||||
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
|
||||
if (remote_device_mgr == nullptr) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
|
||||
"Updating context with an invalid set of remote devices."));
|
||||
}
|
||||
std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
|
||||
std::sort(remote_workers.begin(), remote_workers.end());
|
||||
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
|
||||
&added_workers, &removed_workers,
|
||||
&existing_workers);
|
||||
sg.Update(GetReplacedFromExistingWorkers(
|
||||
&existing_workers, context_id, context->GetContextViewId(), server_def,
|
||||
remote_eager_workers.get(), &replaced_workers));
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(1) << "Updating cluster with following changes";
|
||||
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
|
||||
for (const string& w : removed_workers)
|
||||
VLOG(1) << " Removed worker " << w;
|
||||
for (const string& w : replaced_workers)
|
||||
VLOG(1) << " Replaced worker " << w;
|
||||
}
|
||||
if (!replaced_workers.empty()) {
|
||||
// Treat replaced workers as removed then added back, so that we recreate
|
||||
// remote devices and contexts, and re-register functions on those workers
|
||||
removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
|
||||
replaced_workers.end());
|
||||
added_workers.insert(added_workers.end(), replaced_workers.begin(),
|
||||
replaced_workers.end());
|
||||
for (const string& w : replaced_workers) {
|
||||
existing_workers.erase(
|
||||
std::remove(existing_workers.begin(), existing_workers.end(), w),
|
||||
existing_workers.end());
|
||||
}
|
||||
}
|
||||
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
|
||||
sg.Update(AddRemoteDevicesToMgr(added_workers,
|
||||
grpc_server->master_env()->worker_cache,
|
||||
remote_device_mgr));
|
||||
}
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
|
||||
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
|
||||
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
|
||||
&local_device_attributes);
|
||||
|
||||
// This request makes sure that we can create a Rendezvous properly between
// the local and remote contexts.
|
||||
tensorflow::eager::CreateContextRequest base_request;
|
||||
for (const auto& da : cluster_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
for (const auto& da : local_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
|
||||
// Initialize remote eager workers.
|
||||
if (reset_context) {
|
||||
const tensorflow::Status s = CreateRemoteContexts(
|
||||
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request);
|
||||
// NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
|
||||
// the CreateRemoteContexts to fail. We currently only log instead of
|
||||
// directly returning the error, since returning here will cause the server
|
||||
// object to be destroyed (which currently CHECK-fails). The client will
|
||||
// see additional errors if ops are subsequently sent to the failed workers.
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
LOG(ERROR) << "Error when creating contexts on remote targets: "
|
||||
<< s.error_message()
|
||||
<< "\nExecuting remote ops or functions on these remote "
|
||||
"targets will fail.";
|
||||
}
|
||||
} else {
|
||||
if (sg.ok()) {
|
||||
// Create remote contexts on the newly added workers only if the master
// has collected all device information from them (i.e., the
// GetAllRemoteDevices call returns successfully). Note that in rare cases
// GetAllRemoteDevices can still fail even with RPCs configured to wait
// until the remote workers become alive. If the master creates remote
// contexts on workers whose devices are still not collected, those
// workers will subsequently be treated as existing workers, so the master
// will never get devices from them even when UpdateServerDef is retried.
|
||||
sg.Update(CreateRemoteContexts(
|
||||
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
}
|
||||
if (!existing_workers.empty()) {
|
||||
if (VLOG_IS_ON(1)) {
|
||||
for (const string& w : existing_workers) {
|
||||
VLOG(1) << "Updating cluster with existing worker " << w;
|
||||
}
|
||||
}
|
||||
// The master's context_view_id will be incremented by one in the
|
||||
// UpdateRemoteMaster call later. We want existing workers to also have
|
||||
// the updated context_view_id, so we must set their context_view_id to
|
||||
// the master's current context_view_id + 1.
|
||||
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
|
||||
removed_workers, context_id,
|
||||
context_view_id + 1, server_def,
|
||||
remote_eager_workers.get(), base_request));
|
||||
}
|
||||
}
|
||||
|
||||
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
|
||||
if (reset_context) {
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
|
||||
session_name, &worker_session));
|
||||
|
||||
// Initialize remote tensor communication based on worker session.
|
||||
LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||
|
||||
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
||||
tensorflow::eager::CreateClusterFLR(context_id, context,
|
||||
worker_session.get());
|
||||
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
|
||||
/*is_master=*/true, context);
|
||||
|
||||
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
|
||||
std::move(new_server), grpc_server->worker_env(), worker_session,
|
||||
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
|
||||
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
|
||||
std::move(remote_mgr)));
|
||||
|
||||
// NOTE: We start the server after all other initialization, because the
|
||||
// GrpcServer cannot be destroyed after it is started.
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
} else {
|
||||
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
/*isolate_session_state=*/true));
|
||||
sg.Update(context->UpdateRemoteMaster(context_id,
|
||||
std::move(remote_eager_workers),
|
||||
added_workers, removed_workers));
|
||||
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
|
||||
}
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
@@ -735,7 +105,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
    return nullptr;
#endif
#endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
  }
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
@@ -747,13 +117,18 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());

  return tensorflow::wrap(new tensorflow::EagerContext(
  tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
      /*device_mgr_owned*/ true, r));
      /*device_mgr_owned*/ true, r);
#if !defined(IS_MOBILE_PLATFORM)
  eager_context->SetDistributedManager(
      std::make_unique<tensorflow::EagerContextDistributedManager>(
          eager_context));
#endif  // !IS_MOBILE_PLATFORM
  return tensorflow::wrap(eager_context);
}

void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
@@ -791,26 +166,9 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  if (server_def.has_cluster_device_filters()) {
    const auto& cdf = server_def.cluster_device_filters();
    for (const auto& jdf : cdf.jobs()) {
      const string remote_prefix = "/job:" + jdf.name() + "/task:";
      for (const auto& tdf : jdf.tasks()) {
        const int32_t task_index = tdf.first;
        std::vector<string> device_filters(tdf.second.device_filters_size());
        for (int i = 0; i < tdf.second.device_filters_size(); i++) {
          device_filters[i] = tdf.second.device_filters(i);
        }
        const string remote_worker = remote_prefix + std::to_string(task_index);
        tensorflow::EagerContext* context =
            tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
        status->status =
            context->SetRemoteDeviceFilters(remote_worker, device_filters);
      }
    }
  }
  status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                   ctx, /*reset_context=*/true);
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
          server_def, /*reset_context=*/true, keep_alive_secs);
#endif  // !IS_MOBILE_PLATFORM
}

@@ -835,14 +193,9 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
    status->status = tensorflow::errors::InvalidArgument(
        "Trying to update a context with invalid context id.");
  }
  if (server_def.has_cluster_device_filters()) {
    LOG(WARNING) << "Device filters can only be specified when initializing "
                    "the cluster. Any changes in device filters are ignored "
                    "when updating the server def.";
  }
  // TODO(haoyuzhang): Check server_def compatibility before the update
  status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                   ctx, /*reset_context=*/false);
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
          server_def, /*reset_context=*/false, keep_alive_secs);
#endif  // !IS_MOBILE_PLATFORM
}

@@ -854,44 +207,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
      "TFE_ContextSetServerDef not supported on mobile");
  return false;
#else  // !defined(IS_MOBILE_PLATFORM)
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
  if (grpc_server == nullptr) {
    status->status =
        tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
    return false;
  }
  tensorflow::WorkerInterface* wi =
      grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
  if (wi == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Unable to find worker interface corresponding to task ", worker_name);
    return false;
  }

  tensorflow::GetStatusRequest request;
  tensorflow::GetStatusResponse response;
  tensorflow::Status remote_status;
  tensorflow::Notification done;
  wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
                     [&remote_status, &done](const tensorflow::Status& s) {
                       remote_status = s;
                       done.Notify();
                     });
  done.WaitForNotification();

  // We set an OK status so the call does not raise any exceptions. Instead,
  // callers use the return value to tell whether the remote worker is alive.
  status->status = tensorflow::Status::OK();

  if (remote_status.ok()) {
    return true;
  }
  LOG(INFO) << "Remote worker " << worker_name
            << " is not alive: " << remote_status.error_message();
  return false;
  bool is_alive;
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
          worker_name, &is_alive);
  return is_alive;
#endif  // !IS_MOBILE_PLATFORM
}

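To illustrate the new contract from a caller's perspective: the status returned by CheckRemoteAlive stays OK when the worker is merely unreachable, and liveness is reported through the out-parameter, mirroring the behavior of the code removed above. A minimal sketch (the task name here is made up):

    bool is_alive = false;
    tensorflow::Status s =
        tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
            "/job:worker/replica:0/task:1", &is_alive);
    if (s.ok() && !is_alive) {
      LOG(INFO) << "Worker is unreachable; ops sent to it will fail.";
    }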
tensorflow/c/eager/immediate_execution_context.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
@@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/config.pb.h"
@@ -174,6 +176,18 @@ class ImmediateExecutionContext : public AbstractContext {
  virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
      ImmediateExecutionTensorHandle* handle) = 0;

  //===--------------------------------------------------------------------===//
  // Distributed runtime related functions.
  //===--------------------------------------------------------------------===//
#if !defined(IS_MOBILE_PLATFORM)
  // Set a distributed manager that helps set up, update, and check liveness
  // of member tasks in the cluster.
  virtual void SetDistributedManager(
      std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;

  virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
#endif  // !IS_MOBILE_PLATFORM

 protected:
  explicit ImmediateExecutionContext(AbstractContextKind kind)
      : AbstractContext(kind) {}

tensorflow/c/eager/immediate_execution_distributed_manager.h (new file, 45 lines)
@@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_

#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace tensorflow {
class ImmediateExecutionContext;

class ImmediateExecutionDistributedManager {
 public:
  virtual ~ImmediateExecutionDistributedManager() {}

  // Set up the distributed execution environment on local and remote tasks.
  // When `reset_context` is true, initialize new cluster context state based
  // on the cluster configuration provided in `server_def`; otherwise, update
  // the existing context state with the provided `server_def`. Contexts
  // created on remote tasks will be considered stale and garbage collected
  // after `keep_alive_secs` of inactivity.
  virtual Status SetOrUpdateServerDef(const ServerDef& server_def,
                                      bool reset_context,
                                      int keep_alive_secs) = 0;

  // Check whether the remote task is alive.
  virtual Status CheckRemoteAlive(const std::string& remote_task_name,
                                  bool* is_alive) = 0;
};
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
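The eager runtime's implementation of this interface is declared in the new tensorflow/core/common_runtime/eager/context_distributed_manager.h, which is not reproduced in full in this view. Based on the call sites in this change, its shape is roughly the following; the private member layout is an assumption:

    namespace tensorflow {

    class EagerContextDistributedManager
        : public ImmediateExecutionDistributedManager {
     public:
      explicit EagerContextDistributedManager(EagerContext* context)
          : context_(context) {}

      Status SetOrUpdateServerDef(const ServerDef& server_def,
                                  bool reset_context,
                                  int keep_alive_secs) override;

      Status CheckRemoteAlive(const std::string& remote_task_name,
                              bool* is_alive) override;

     private:
      EagerContext* context_;  // Not owned; the context owns this manager.
    };

    }  // namespace tensorflow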
tensorflow/core/common_runtime/eager/BUILD
@@ -80,6 +80,7 @@ tf_cuda_library(
        "@com_google_absl//absl/container:flat_hash_map",
        "//tensorflow/c:tf_tensor_internal",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_distributed_manager",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
@@ -109,6 +110,47 @@ tf_cuda_library(
    }),
)

tf_cuda_library(
    name = "context_distributed_manager",
    srcs = [
        "context_distributed_manager.cc",
    ],
    hdrs = [
        "context_distributed_manager.h",
    ],
    visibility = ["//tensorflow:internal"],
    deps = [
        ":context",
        ":eager_executor",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_distributed_manager",
        "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
        "//tensorflow/core/distributed_runtime:worker_env",
    ] + select({
        "//tensorflow:android": [
            "//tensorflow/core:portable_tensorflow_lib_lite",
        ],
        "//conditions:default": [
            "//tensorflow/core:core_cpu_lib",
            "//tensorflow/core:framework",
            "//tensorflow/core:framework_internal",
            "//tensorflow/core:lib",
            "//tensorflow/core:lib_internal",
            "//tensorflow/core:protos_all_cc",
            "//tensorflow/core:session_options",
            "//tensorflow/core/distributed_runtime:worker_cache",
            "//tensorflow/core/distributed_runtime:worker_interface",
            "//tensorflow/core/distributed_runtime:remote_device",
            "//tensorflow/core/distributed_runtime:server_lib",
            "//tensorflow/core/distributed_runtime:worker_session",
            "//tensorflow/core/distributed_runtime/eager:eager_client",
            "//tensorflow/core/distributed_runtime/eager:cluster_function_library_runtime",
            "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
            "//tensorflow/core/distributed_runtime/eager:remote_mgr",
        ],
    }),
)

tf_cc_test(
    name = "context_test",
    srcs = ["context_test.cc"],

tensorflow/core/common_runtime/eager/context.h
@@ -24,15 +24,8 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
// clang-format off
|
||||
// Required for IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
@ -40,12 +33,34 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/example/example.pb.h"
|
||||
#include "tensorflow/core/framework/collective.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
// "tensorflow/core/platform/platform.h" must be included first before using
|
||||
// IS_MOBILE_PLATFORM.
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||
@ -53,23 +68,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/framework/collective.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@@ -442,6 +440,15 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
    return ptr->getKind() == kEager;
  }

  // Functions to support the distributed C API.
  void SetDistributedManager(
      std::unique_ptr<ImmediateExecutionDistributedManager> distributed)
      override {
    distributed_manager_ = std::move(distributed);
  }
  ImmediateExecutionDistributedManager* GetDistributedManager() override {
    return distributed_manager_.get();
  }
#endif  // IS_MOBILE_PLATFORM

  // Closes remote eager contexts, waits for all RPCs to finish, and
@@ -671,6 +678,10 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
  std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
      cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_);

  // A distributed manager that helps set up, update, and check liveness of
  // member tasks in the cluster.
  std::unique_ptr<ImmediateExecutionDistributedManager> distributed_manager_;

#endif  // IS_MOBILE_PLATFORM

  // For a multi device function, the target device of each input is unknown

tensorflow/core/common_runtime/eager/context_distributed_manager.cc (new file, 716 lines)
@@ -0,0 +1,716 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/device_filters.pb.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
namespace tensorflow {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
namespace {
|
||||
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
|
||||
const tensorflow::ServerDef& server_def) {
|
||||
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
|
||||
return false;
|
||||
}
|
||||
return server_def.default_session_config().SerializeAsString() ==
|
||||
context->session_options().config.SerializeAsString();
|
||||
}
|
||||
|
||||
tensorflow::Status AddRemoteDevicesToMgr(
    const std::vector<string>& added_remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    tensorflow::DynamicDeviceMgr* remote_device_mgr) {
  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
  tensorflow::mutex remote_devices_mu;
  int num_added_workers = added_remote_workers.size();
  tensorflow::BlockingCounter counter(num_added_workers);
  std::vector<tensorflow::Status> statuses(num_added_workers);
  for (int i = 0; i < num_added_workers; i++) {
    tensorflow::NewRemoteDevices(
        tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
        [i, &statuses, &counter, &remote_devices, &remote_devices_mu](
            const tensorflow::Status& s,
            std::vector<tensorflow::Device*>* devices) {
          statuses[i] = s;
          if (s.ok()) {
            tensorflow::mutex_lock l(remote_devices_mu);
            for (tensorflow::Device* d : *devices) {
              remote_devices.emplace_back(d);
            }
          }
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (int i = 0; i < num_added_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }

  TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
  return tensorflow::Status::OK();
}

tensorflow::Status GetAllRemoteDevices(
    const std::vector<string>& remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
  auto remote_device_mgr = std::make_unique<tensorflow::DynamicDeviceMgr>();
  TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
                                           remote_device_mgr.get()));
  *device_mgr = std::move(remote_device_mgr);
  return tensorflow::Status::OK();
}

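// Removes from `remote_device_mgr` every device that lives in the same
// address space (job/replica/task) as one of the removed remote workers.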
tensorflow::Status RemoveRemoteDevicesFromMgr(
    const std::vector<string>& removed_remote_workers,
    tensorflow::DynamicDeviceMgr* remote_device_mgr) {
  const std::vector<tensorflow::Device*> remote_devices =
      (remote_device_mgr->ListDevices());
  std::vector<tensorflow::Device*> devices_to_remove;
  for (tensorflow::Device* d : remote_devices) {
    for (const string& remote_worker : removed_remote_workers) {
      if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
                                                          d->name())) {
        devices_to_remove.emplace_back(d);
        break;
      }
    }
  }
  TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
  return tensorflow::Status::OK();
}

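// Lists all workers known to the gRPC server's worker cache, excluding the
// local worker. Fails if `server` is not a tensorflow::GrpcServer.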
tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
                                     const string& local_worker,
                                     std::vector<string>* remote_workers) {
  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(server);
  if (grpc_server == nullptr) {
    return tensorflow::errors::Internal(
        "Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
  }
  grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
  remote_workers->erase(
      std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
      remote_workers->end());
  return tensorflow::Status::OK();
}

void DifferentiateWorkerLists(const std::vector<string>* current_list,
                              const std::vector<string>* new_list,
                              std::vector<string>* added,
                              std::vector<string>* removed,
                              std::vector<string>* existing) {
  // Computes the set difference and set intersection in a single traversal of
  // both lists. As with the std::set_difference library function, the input
  // lists (`current_list` and `new_list`) must be sorted before calling this
  // function.
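  // For example (hypothetical task names), given the sorted inputs
  //   current_list = {"/job:worker/task:0", "/job:worker/task:1"} and
  //   new_list     = {"/job:worker/task:1", "/job:worker/task:2"},
  // the outputs are added = {"/job:worker/task:2"},
  // removed = {"/job:worker/task:0"}, and existing = {"/job:worker/task:1"}.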
  added->resize(new_list->size());
  removed->resize(current_list->size());
  existing->resize(current_list->size());
  std::vector<string>::const_iterator curr_it = current_list->begin();
  std::vector<string>::const_iterator new_it = new_list->begin();
  std::vector<string>::iterator added_it = added->begin();
  std::vector<string>::iterator removed_it = removed->begin();
  std::vector<string>::iterator existing_it = existing->begin();
  while (curr_it != current_list->end() && new_it != new_list->end()) {
    if (*curr_it < *new_it) {
      *removed_it++ = *curr_it++;
    } else if (*curr_it > *new_it) {
      *added_it++ = *new_it++;
    } else {
      *existing_it++ = *curr_it++;
      new_it++;
    }
  }
  removed_it = std::copy(curr_it, current_list->end(), removed_it);
  added_it = std::copy(new_it, new_list->end(), added_it);
  added->resize(added_it - added->begin());
  removed->resize(removed_it - removed->begin());
  existing->resize(existing_it - existing->begin());
}

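// Sends a KeepAlive request to every existing worker. A worker is reported as
// replaced if the RPC fails (the context id is unknown on the remote side) or
// if it answers with a context view id different from the local one (a stale
// view of the cluster).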
tensorflow::Status GetReplacedFromExistingWorkers(
    const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* client_cache,
    std::vector<string>* replaced_workers) {
  tensorflow::BlockingCounter counter(existing_workers->size());
  std::vector<tensorflow::Status> statuses(existing_workers->size());
  tensorflow::eager::KeepAliveRequest request;
  request.set_context_id(context_id);
  std::vector<tensorflow::eager::KeepAliveResponse> responses(
      existing_workers->size());
  for (int i = 0; i < existing_workers->size(); i++) {
    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] =
        client_cache->GetClient(existing_workers->at(i), &eager_client);
    if (!statuses[i].ok()) {
      counter.DecrementCount();
      continue;
    }
    eager_client->KeepAliveAsync(
        &request, &responses[i],
        [i, &statuses, &counter](const tensorflow::Status& s) {
          statuses[i] = s;
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (int i = 0; i < existing_workers->size(); i++) {
    // If the RPC fails (indicating that the requested ID doesn't exist on the
    // remote), or the returned view ID is not equal to the local one
    // (indicating that the remote worker has a stale view of the cluster),
    // treat the worker as replaced.
    if (!statuses[i].ok() ||
        responses[i].context_view_id() != context_view_id) {
      replaced_workers->emplace_back(existing_workers->at(i));
    }
  }
  return tensorflow::Status::OK();
}

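// Sends a CreateContextRequest, carrying the given context id / view id and
// the cluster device attributes that pass each worker's device filters, to
// every remote worker in parallel, then aggregates the per-worker statuses
// into a single summary status.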
tensorflow::Status CreateRemoteContexts(
    EagerContext* context, const std::vector<string>& remote_workers,
    tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
    const bool lazy_copy_remote_function_inputs,
    const tensorflow::eager::CreateContextRequest& base_request) {
  int num_remote_workers = remote_workers.size();
  tensorflow::BlockingCounter counter(num_remote_workers);
  std::vector<tensorflow::Status> statuses(num_remote_workers);
  for (int i = 0; i < num_remote_workers; i++) {
    const string& remote_worker = remote_workers[i];
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
                                                    &parsed_name)) {
      statuses[i] = tensorflow::errors::InvalidArgument(
          "Unable to parse ", remote_worker, " as a device name");
      counter.DecrementCount();
      continue;
    }

    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
    if (eager_client == nullptr) {
      statuses[i] = tensorflow::errors::Internal(
          "Cannot find a client for the given target:", remote_worker);
    }
    if (!statuses[i].ok()) {
      counter.DecrementCount();
      continue;
    }

    tensorflow::eager::CreateContextRequest request;
    tensorflow::eager::CreateContextResponse* response =
        new tensorflow::eager::CreateContextResponse();
    request.set_context_id(context_id);
    request.set_context_view_id(context_view_id);
    *request.mutable_server_def() = server_def;
    request.mutable_server_def()->set_job_name(parsed_name.job);
    request.mutable_server_def()->set_task_index(parsed_name.task);
    request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
        server_def.default_session_config());

    std::vector<bool> filtered_device_mask;
    context->FilterDevicesForRemoteWorkers(
        remote_worker, base_request.cluster_device_attributes(),
        &filtered_device_mask);
    DCHECK_EQ(filtered_device_mask.size(),
              base_request.cluster_device_attributes_size());
    for (int i = 0; i < filtered_device_mask.size(); i++) {
      if (filtered_device_mask[i]) {
        const auto& da = base_request.cluster_device_attributes(i);
        *request.add_cluster_device_attributes() = da;
      }
    }
    request.set_async(async);
    request.set_keep_alive_secs(keep_alive_secs);
    request.set_lazy_copy_remote_function_inputs(
        lazy_copy_remote_function_inputs);

    eager_client->CreateContextAsync(
        &request, response,
        [i, &statuses, &counter, response](const tensorflow::Status& s) {
          statuses[i] = s;
          delete response;
          counter.DecrementCount();
        });
  }
  counter.Wait();
  tensorflow::StatusGroup sg;
  for (int i = 0; i < num_remote_workers; i++) {
    if (TF_PREDICT_FALSE(!statuses[i].ok())) {
      sg.Update(statuses[i]);
    }
  }
  return sg.as_summary_status();
}

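// Sends an UpdateContextRequest to each existing remote worker. A worker whose
// filtered device set intersects the added or removed workers receives a full
// update (new ServerDef plus device attributes); otherwise it receives a
// lightweight request that only advances the context view id.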
tensorflow::Status UpdateRemoteContexts(
    EagerContext* context, const std::vector<string>& remote_workers,
    const std::vector<string>& added_workers,
    const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers,
    const tensorflow::eager::CreateContextRequest& base_request) {
  int num_remote_workers = remote_workers.size();
  tensorflow::BlockingCounter counter(num_remote_workers);
  std::vector<tensorflow::Status> statuses(num_remote_workers);

  int cluster_device_count = base_request.cluster_device_attributes_size();
  std::unordered_set<string> added_or_removed(added_workers.begin(),
                                              added_workers.end());
  std::copy(removed_workers.begin(), removed_workers.end(),
            std::inserter(added_or_removed, added_or_removed.end()));
  // Whether each device is in the updated (added or removed) workers
  std::vector<bool> device_added_or_removed(cluster_device_count);
  for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
    const auto& da = base_request.cluster_device_attributes().at(i);
    tensorflow::DeviceNameUtils::ParsedName pn;
    tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
    string task_name;
    tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
    if (added_or_removed.find(task_name) != added_or_removed.end()) {
      device_added_or_removed[i] = true;
    }
  }

  for (int i = 0; i < num_remote_workers; i++) {
    const string& remote_worker = remote_workers[i];
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
                                                    &parsed_name)) {
      statuses[i] = tensorflow::errors::InvalidArgument(
          "Unable to parse ", remote_worker, " as a device name");
      counter.DecrementCount();
      continue;
    }

    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
    if (eager_client == nullptr) {
      statuses[i] = tensorflow::errors::Internal(
          "Cannot find a client for the given target:", remote_worker);
    }
    if (!statuses[i].ok()) {
      counter.DecrementCount();
      continue;
    }

    std::vector<bool> filtered_device_mask;
    context->FilterDevicesForRemoteWorkers(
        remote_worker, base_request.cluster_device_attributes(),
        &filtered_device_mask);
    DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);

    // If any of the devices that match the device filters are in the set of
    // added or removed workers, we must send a complete UpdateContextRequest.
    // Otherwise, only send a simple request to increment context view ID.
    std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
    std::transform(device_added_or_removed.begin(),
                   device_added_or_removed.end(), filtered_device_mask.begin(),
                   added_or_removed_filtered_devices.begin(),
                   std::logical_and<bool>());
    const bool full_update_request =
        std::accumulate(added_or_removed_filtered_devices.begin(),
                        added_or_removed_filtered_devices.end(), false,
                        std::logical_or<bool>());

    tensorflow::eager::UpdateContextRequest request;
    auto* response = new tensorflow::eager::UpdateContextResponse();
    request.set_context_id(context_id);
    request.set_context_view_id(context_view_id);
    if (full_update_request) {
      *request.mutable_server_def() = server_def;
      request.mutable_server_def()->set_job_name(parsed_name.job);
      request.mutable_server_def()->set_task_index(parsed_name.task);
      request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
          server_def.default_session_config());
      for (int i = 0; i < cluster_device_count; i++) {
        if (filtered_device_mask[i]) {
          const auto& da = base_request.cluster_device_attributes(i);
          *request.add_cluster_device_attributes() = da;
        }
      }
    }

    eager_client->UpdateContextAsync(
        &request, response,
        [i, &statuses, &counter, response](const tensorflow::Status& s) {
          statuses[i] = s;
          delete response;
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (int i = 0; i < num_remote_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }
  return tensorflow::Status::OK();
}

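// Creates (when `reset_context` is true) or updates the distributed eager
// context described by `server_def`: starts or updates the gRPC server, diffs
// the remote worker list, creates or updates remote contexts, and finally
// initializes or updates the master context.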
tensorflow::Status UpdateContextWithServerDef(
    EagerContext* context, const tensorflow::ServerDef& server_def,
    bool reset_context, int keep_alive_secs) {
  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
  // server object (which currently CHECK-fails) and we would miss the error;
  // instead, we log the error and then return, so that the user can see the
  // error message.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0);

  string worker_name =
      tensorflow::strings::StrCat("/job:", server_def.job_name(),
                                  "/replica:0/task:", server_def.task_index());

  // List of current remote workers before updating server_def. Unused if
  // resetting the server_def.
  std::vector<string> curr_remote_workers;
  // List of updated remote workers.
  std::vector<string> remote_workers;

  // New server created for new server_def. Unused if updating server_def.
  std::unique_ptr<tensorflow::ServerInterface> new_server;
  tensorflow::GrpcServer* grpc_server;
  if (reset_context) {
    const tensorflow::DeviceMgr* device_mgr =
        AreLocalDevicesCompatible(context, server_def)
            ? context->local_device_mgr()
            : nullptr;
    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
        server_def, {device_mgr}, &new_server));
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
    LOG_AND_RETURN_IF_ERROR(
        ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
  } else {
    LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
                                              &curr_remote_workers));
    // No need to check the cast here, since `ListRemoteWorkers` already checks
    // if the server is a GRPC server or not.
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(
        ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
  }

  tensorflow::uint64 context_id = context->GetContextId();
  tensorflow::uint64 context_view_id = context->GetContextViewId();
  if (reset_context) {
    context_id = tensorflow::EagerContext::NewContextId();
    context_view_id = 0;
    // Make the master eager context accessible by the local eager service,
    // which might receive send-tensor requests from remote workers.
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->AddMasterEagerContextToEagerService(context_id, context));
  }

  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
  LOG_AND_RETURN_IF_ERROR(
      grpc_server->master_env()->worker_cache->GetEagerClientCache(
          &remote_eager_workers));

  // For cluster update, use a status group to aggregate statuses from
  // * adding and removing remote devices
  // * creating remote contexts on newly added workers
  // * updating remote contexts on existing workers
  // * updating the master context
  // Note that we should not return immediately on errors in the middle of
  // these updates to prevent the cluster from having inconsistent context
  // views.
  //
  // Unused if `reset_context` is true.
  tensorflow::StatusGroup sg;

  // When updating an existing context, populate the following lists with:
  // * added_workers: set(remote_workers) - set(curr_remote_workers)
  // * removed_workers: set(curr_remote_workers) - set(remote_workers)
  // * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
  // * replaced_workers: workers with the same task names and potentially the
  //   same `hostname:port`s, but replaced by different processes
  std::vector<string> added_workers;
  std::vector<string> removed_workers;
  std::vector<string> existing_workers;
  std::vector<string> replaced_workers;

  // New remote device manager created for new server_def. Unused if updating
  // server_def.
  std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
  tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
  if (reset_context) {
    LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
        remote_workers, grpc_server->master_env()->worker_cache,
        &new_remote_device_mgr));
    remote_device_mgr = new_remote_device_mgr.get();
  } else {
    context->ClearCachesAndDefaultExecutor();
    // TODO(b/143914772): Potential memory leak if rendezvous has pending
    // tensors for removed / replaced workers.

    remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
    if (remote_device_mgr == nullptr) {
      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
          "Updating context with an invalid set of remote devices."));
    }
    std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
    std::sort(remote_workers.begin(), remote_workers.end());
    DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
                             &added_workers, &removed_workers,
                             &existing_workers);
    sg.Update(GetReplacedFromExistingWorkers(
        &existing_workers, context_id, context->GetContextViewId(), server_def,
        remote_eager_workers.get(), &replaced_workers));
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Updating cluster with following changes";
      for (const string& w : added_workers) VLOG(1) << "  Added worker " << w;
      for (const string& w : removed_workers)
        VLOG(1) << "  Removed worker " << w;
      for (const string& w : replaced_workers)
        VLOG(1) << "  Replaced worker " << w;
    }
    if (!replaced_workers.empty()) {
      // Treat replaced workers as removed then added back, so that we recreate
      // remote devices and contexts, and re-register functions on those
      // workers.
      removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
                             replaced_workers.end());
      added_workers.insert(added_workers.end(), replaced_workers.begin(),
                           replaced_workers.end());
      for (const string& w : replaced_workers) {
        existing_workers.erase(
            std::remove(existing_workers.begin(), existing_workers.end(), w),
            existing_workers.end());
      }
    }
    sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
    sg.Update(AddRemoteDevicesToMgr(added_workers,
                                    grpc_server->master_env()->worker_cache,
                                    remote_device_mgr));
  }

  std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
  remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);

  std::vector<tensorflow::DeviceAttributes> local_device_attributes;
  grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
      &local_device_attributes);

  // This request makes sure that we can create a Rendezvous properly between
  // the local and remote contexts.
  tensorflow::eager::CreateContextRequest base_request;
  for (const auto& da : cluster_device_attributes) {
    *base_request.add_cluster_device_attributes() = da;
  }
  for (const auto& da : local_device_attributes) {
    *base_request.add_cluster_device_attributes() = da;
  }

  // Initialize remote eager workers.
  if (reset_context) {
    const tensorflow::Status s = CreateRemoteContexts(
        context, remote_workers, context_id, context_view_id, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        context->LazyCopyFunctionRemoteInputs(), base_request);
    // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
    // the CreateRemoteContexts to fail. We currently only log instead of
    // directly returning the error, since returning here will cause the server
    // object to be destroyed (which currently CHECK-fails). The client will
    // see additional errors if ops are subsequently sent to the failed workers.
    if (TF_PREDICT_FALSE(!s.ok())) {
      LOG(ERROR) << "Error when creating contexts on remote targets: "
                 << s.error_message()
                 << "\nExecuting remote ops or functions on these remote "
                    "targets will fail.";
    }
  } else {
    if (sg.ok()) {
      // Create remote contexts on the newly added workers only if the master
      // has collected all device information from them (i.e., the
      // GetAllRemoteDevices call returns successfully). Note that in rare cases
      // GetAllRemoteDevices can still fail even with RPCs configured to wait
      // until the remote workers become alive. If the master creates remote
      // contexts on the workers whose devices are still not collected, those
      // workers will be treated as existing workers subsequently, so the master
      // will never get devices from them even with retrying UpdateServerDef.
      sg.Update(CreateRemoteContexts(
          context, added_workers, context_id, context_view_id + 1,
          keep_alive_secs, server_def, remote_eager_workers.get(),
          context->Executor().Async(), context->LazyCopyFunctionRemoteInputs(),
          base_request));
    }
    if (!existing_workers.empty()) {
      if (VLOG_IS_ON(1)) {
        for (const string& w : existing_workers) {
          VLOG(1) << "Updating cluster with existing worker " << w;
        }
      }
      // The master's context_view_id will be incremented by one in the
      // UpdateRemoteMaster call later. We want existing workers to also have
      // the updated context_view_id, so we must set their context_view_id to
      // the master's current context_view_id + 1.
      sg.Update(UpdateRemoteContexts(context, existing_workers, added_workers,
                                     removed_workers, context_id,
                                     context_view_id + 1, server_def,
                                     remote_eager_workers.get(), base_request));
    }
  }

  auto session_name = tensorflow::strings::StrCat("eager_", context_id);
  if (reset_context) {
    tensorflow::RemoteRendezvous* r =
        grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
    auto* device_mgr = grpc_server->worker_env()->device_mgr;
    std::shared_ptr<tensorflow::WorkerSession> worker_session;
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->worker_env()->session_mgr->CreateSession(
            session_name, server_def, base_request.cluster_device_attributes(),
            true));
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
            session_name, &worker_session));

    // Initialize remote tensor communication based on worker session.
    LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));

    tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
        tensorflow::eager::CreateClusterFLR(context_id, context,
                                            worker_session.get());
    auto remote_mgr = std::make_unique<tensorflow::eager::RemoteMgr>(
        /*is_master=*/true, context);

    LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
        std::move(new_server), grpc_server->worker_env(), worker_session,
        std::move(remote_eager_workers), std::move(new_remote_device_mgr),
        remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
        std::move(remote_mgr)));

    // NOTE: We start the server after all other initialization, because the
    // GrpcServer cannot be destroyed after it is started.
    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
  } else {
    sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
        session_name, server_def, base_request.cluster_device_attributes(),
        /*isolate_session_state=*/true));
    sg.Update(context->UpdateRemoteMaster(context_id,
                                          std::move(remote_eager_workers),
                                          added_workers, removed_workers));
    LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
  }
#undef LOG_AND_RETURN_IF_ERROR

  return tensorflow::Status::OK();
}
}  // namespace

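// Records any per-task device filters from the ServerDef (honored only when
// the context is being reset; otherwise a warning is logged) and then
// delegates the actual cluster initialization or update to
// UpdateContextWithServerDef().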
Status EagerContextDistributedManager::SetOrUpdateServerDef(
    const ServerDef& server_def, bool reset_context, int keep_alive_secs) {
  if (server_def.has_cluster_device_filters()) {
    if (reset_context) {
      const auto& cdf = server_def.cluster_device_filters();
      for (const auto& jdf : cdf.jobs()) {
        const string remote_prefix = "/job:" + jdf.name() + "/task:";
        for (const auto& tdf : jdf.tasks()) {
          const int32_t task_index = tdf.first;
          std::vector<string> device_filters(tdf.second.device_filters_size());
          for (int i = 0; i < tdf.second.device_filters_size(); i++) {
            device_filters[i] = tdf.second.device_filters(i);
          }
          const string remote_worker =
              strings::StrCat(remote_prefix, task_index);
          TF_RETURN_IF_ERROR(
              context_->SetRemoteDeviceFilters(remote_worker, device_filters));
        }
      }
    } else {
      LOG(WARNING) << "Device filters can only be specified when initializing "
                      "the cluster. Any changes in device filters are ignored "
                      "when updating the server def.";
    }
  }
  return UpdateContextWithServerDef(context_, server_def, reset_context,
                                    keep_alive_secs);
}

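// Probes the given remote task with a GetStatus RPC and reports the result
// through `*is_alive`. An RPC failure is logged and surfaced as
// `*is_alive == false` rather than as an error status.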
Status EagerContextDistributedManager::CheckRemoteAlive(
    const std::string& remote_task_name, bool* is_alive) {
  *is_alive = false;
  GrpcServer* grpc_server = dynamic_cast<GrpcServer*>(context_->GetServer());
  if (grpc_server == nullptr) {
    return errors::Internal("Failed to get eager-compatible server instance.");
  }
  WorkerInterface* wi =
      grpc_server->master_env()->worker_cache->GetOrCreateWorker(
          remote_task_name);
  if (wi == nullptr) {
    return errors::InvalidArgument(
        "Unable to find worker interface corresponding to task ",
        remote_task_name);
  }

  GetStatusRequest request;
  GetStatusResponse response;
  Status remote_status;
  Notification done;
  wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
                     [&remote_status, &done](const Status& s) {
                       remote_status = s;
                       done.Notify();
                     });
  done.WaitForNotification();

  if (remote_status.ok()) {
    *is_alive = true;
  } else {
    // Only log when the worker did not respond successfully; logging
    // unconditionally would report "is not alive" for healthy workers too.
    LOG(INFO) << "Remote worker " << remote_task_name
              << " is not alive: " << remote_status.error_message();
  }
  return Status::OK();
}
#endif  // !IS_MOBILE_PLATFORM
}  // namespace tensorflow
@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_DISTRIBUTED_MANAGER_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_DISTRIBUTED_MANAGER_H_

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace tensorflow {
class EagerContext;

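// Distributed-runtime manager owned by the eager EagerContext. It hosts the
// cluster initialization and update logic that previously lived in the eager
// C API: SetOrUpdateServerDef() creates or updates the distributed context
// from a ServerDef, and CheckRemoteAlive() checks whether a remote task is
// reachable.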
class EagerContextDistributedManager
    : public ImmediateExecutionDistributedManager {
 public:
  explicit EagerContextDistributedManager(EagerContext* context)
      : context_(context) {}

  Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context,
                              int keep_alive_secs) override;

  Status CheckRemoteAlive(const std::string& remote_task_name,
                          bool* is_alive) override;

 private:
  EagerContext* context_;
};
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_DISTRIBUTED_MANAGER_H_