diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 5c085a4b9b8..08dd5d0820e 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -51,6 +51,7 @@ tf_cuda_library( ":immediate_execution_context", ":immediate_execution_operation", ":immediate_execution_tensor_handle", + ":immediate_execution_distributed_manager", ":abstract_tensor_handle", ":tfe_context_internal", ":tfe_cancellation_manager_internal", @@ -70,6 +71,7 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:context_distributed_manager", "//tensorflow/core/common_runtime/eager:core", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", @@ -119,6 +121,7 @@ filegroup( "gradients.h", "gradients_internal.h", "immediate_execution_context.h", + "immediate_execution_distributed_manager.h", "immediate_execution_operation.h", "immediate_execution_tensor_handle.h", "tape.h", @@ -584,6 +587,19 @@ cc_library( ], ) +cc_library( + name = "immediate_execution_distributed_manager", + hdrs = ["immediate_execution_distributed_manager.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "immediate_execution_context", hdrs = ["immediate_execution_context.h"], @@ -592,12 +608,14 @@ cc_library( ], deps = [ ":abstract_context", + ":immediate_execution_distributed_manager", ":immediate_execution_operation", ":immediate_execution_tensor_handle", "//tensorflow/c:tensor_interface", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 9c73d1aba8c..88705cf3058 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -21,16 +21,11 @@ limitations under the License. #include #include -#include "tensorflow/c/eager/abstract_tensor_handle.h" - -// clang-format off -#include "tensorflow/core/platform/platform.h" -// clang-format on - #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/immediate_execution_operation.h" @@ -39,59 +34,39 @@ limitations under the License. 
#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_tensor_internal.h" -#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) -#include "tensorflow/core/tfrt/eager/c_api_tfrt.h" -#endif -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/eager/context.h" -#include "tensorflow/core/framework/device_attributes.pb.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/protobuf/device_filters.pb.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/common_runtime/copy_tensor.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" -#include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" +#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/rendezvous_mgr.h" -#if !defined(IS_MOBILE_PLATFORM) -#include "tensorflow/core/distributed_runtime/eager/eager_client.h" -#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" -#include "tensorflow/core/distributed_runtime/remote_device.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" -#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" -#include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/core/distributed_runtime/worker_interface.h" -#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" -#endif // !IS_MOBILE_PLATFORM +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/casts.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/notification.h" -#include "tensorflow/core/platform/random.h" -#include "tensorflow/core/platform/refcount.h" -#include "tensorflow/core/platform/stringpiece.h" -#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/public/version.h" +// "tensorflow/core/platform/platform.h" must be included first before using +// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc. 
+#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) +#include "tensorflow/core/tfrt/eager/c_api_tfrt.h" +#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE + +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h" +#endif // !IS_MOBILE_PLATFORM + using tensorflow::string; namespace { @@ -100,611 +75,6 @@ string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? "cpu:0" : d->name(); } -#if !defined(IS_MOBILE_PLATFORM) -bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context, - const tensorflow::ServerDef& server_def) { - if (server_def.job_name() != context->HostCPU()->parsed_name().job) { - return false; - } - return server_def.default_session_config().SerializeAsString() == - context->session_options().config.SerializeAsString(); -} - -tensorflow::Status AddRemoteDevicesToMgr( - const std::vector& added_remote_workers, - tensorflow::WorkerCacheInterface* worker_cache, - tensorflow::DynamicDeviceMgr* remote_device_mgr) { - std::vector> remote_devices; - tensorflow::mutex remote_devices_mu; - int num_added_workers = added_remote_workers.size(); - tensorflow::BlockingCounter counter(num_added_workers); - std::vector statuses(num_added_workers); - for (int i = 0; i < num_added_workers; i++) { - tensorflow::NewRemoteDevices( - tensorflow::Env::Default(), worker_cache, added_remote_workers[i], - [i, &statuses, &counter, &remote_devices, &remote_devices_mu]( - const tensorflow::Status& s, - std::vector* devices) { - statuses[i] = s; - if (s.ok()) { - tensorflow::mutex_lock l(remote_devices_mu); - for (tensorflow::Device* d : *devices) { - remote_devices.emplace_back(d); - } - } - counter.DecrementCount(); - }); - } - counter.Wait(); - for (int i = 0; i < num_added_workers; i++) { - TF_RETURN_IF_ERROR(statuses[i]); - } - - TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices))); - return tensorflow::Status::OK(); -} - -tensorflow::Status GetAllRemoteDevices( - const std::vector& remote_workers, - tensorflow::WorkerCacheInterface* worker_cache, - std::unique_ptr* device_mgr) { - auto remote_device_mgr = absl::make_unique(); - TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache, - remote_device_mgr.get())); - *device_mgr = std::move(remote_device_mgr); - return tensorflow::Status::OK(); -} - -tensorflow::Status RemoveRemoteDevicesFromMgr( - const std::vector& removed_remote_workers, - tensorflow::DynamicDeviceMgr* remote_device_mgr) { - const std::vector remote_devices = - (remote_device_mgr->ListDevices()); - std::vector devices_to_remove; - for (tensorflow::Device* d : remote_devices) { - for (const string& remote_worker : removed_remote_workers) { - if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker, - d->name())) { - devices_to_remove.emplace_back(d); - break; - } - } - } - TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove)); - return tensorflow::Status::OK(); -} - -tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server, - const string& local_worker, - std::vector* remote_workers) { - tensorflow::GrpcServer* grpc_server = - dynamic_cast(server); - if (grpc_server == nullptr) { - return tensorflow::errors::Internal( - "Currently, TFE_NewContext only supports tensorflow::GrpcServer."); - } - grpc_server->master_env()->worker_cache->ListWorkers(remote_workers); - remote_workers->erase( - std::remove(remote_workers->begin(), remote_workers->end(), local_worker), - remote_workers->end()); - return tensorflow::Status::OK(); -} 
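The helpers in this removed block — AddRemoteDevicesToMgr, GetReplacedFromExistingWorkers, CreateRemoteContexts, and UpdateRemoteContexts (all re-added verbatim in context_distributed_manager.cc below) — share one fan-out/join idiom: issue one asynchronous call per worker, record a per-worker Status, and block on a BlockingCounter until every callback has fired. A minimal self-contained sketch of that idiom, assuming only the TF headers shown; the RunAndJoin name and callback signature are illustrative, not part of this patch:

#include <functional>
#include <vector>

#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

// Issues `n` asynchronous tasks and blocks until each one has reported a
// Status through its done-callback; returns the first error, if any.
tensorflow::Status RunAndJoin(
    int n,
    const std::function<void(
        int, std::function<void(const tensorflow::Status&)>)>& start_task) {
  std::vector<tensorflow::Status> statuses(n);
  tensorflow::BlockingCounter counter(n);
  for (int i = 0; i < n; ++i) {
    start_task(i, [i, &statuses, &counter](const tensorflow::Status& s) {
      statuses[i] = s;           // Each callback owns exactly one slot.
      counter.DecrementCount();  // Signal completion, success or failure.
    });
  }
  counter.Wait();  // Join: unblocks once all n callbacks have run.
  for (const tensorflow::Status& s : statuses) {
    TF_RETURN_IF_ERROR(s);
  }
  return tensorflow::Status::OK();
}

Keeping one Status slot per index makes the join lock-free (each callback writes a distinct element, so only the counter needs synchronization), and it is also why these helpers never return early inside the loop: the counter must reach zero before the stack-allocated vectors the callbacks reference go out of scope.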
- -void DifferentiateWorkerLists(const std::vector* current_list, - const std::vector* new_list, - std::vector* added, - std::vector* removed, - std::vector* existing) { - // Get STL set_difference and set_intersection with one list traversal. - // Similar to the set_difference library function, the input lists - // (`current_list` and `new_list`) must be sorted before calling the function. - added->resize(new_list->size()); - removed->resize(current_list->size()); - existing->resize(current_list->size()); - std::vector::const_iterator curr_it = current_list->begin(); - std::vector::const_iterator new_it = new_list->begin(); - std::vector::iterator added_it = added->begin(); - std::vector::iterator removed_it = removed->begin(); - std::vector::iterator existing_it = existing->begin(); - while (curr_it != current_list->end() && new_it != new_list->end()) { - if (*curr_it < *new_it) { - *removed_it++ = *curr_it++; - } else if (*curr_it > *new_it) { - *added_it++ = *new_it++; - } else { - *existing_it++ = *curr_it++; - new_it++; - } - } - removed_it = std::copy(curr_it, current_list->end(), removed_it); - added_it = std::copy(new_it, new_list->end(), added_it); - added->resize(added_it - added->begin()); - removed->resize(removed_it - removed->begin()); - existing->resize(existing_it - existing->begin()); -} - -tensorflow::Status GetReplacedFromExistingWorkers( - const std::vector* existing_workers, tensorflow::uint64 context_id, - tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def, - tensorflow::eager::EagerClientCache* client_cache, - std::vector* replaced_workers) { - tensorflow::BlockingCounter counter(existing_workers->size()); - std::vector statuses(existing_workers->size()); - tensorflow::eager::KeepAliveRequest request; - request.set_context_id(context_id); - std::vector responses( - existing_workers->size()); - for (int i = 0; i < existing_workers->size(); i++) { - tensorflow::core::RefCountPtr eager_client; - statuses[i] = - client_cache->GetClient(existing_workers->at(i), &eager_client); - if (!statuses[i].ok()) { - counter.DecrementCount(); - continue; - } - eager_client->KeepAliveAsync( - &request, &responses[i], - [i, &statuses, &counter](const tensorflow::Status& s) { - statuses[i] = s; - counter.DecrementCount(); - }); - } - counter.Wait(); - for (int i = 0; i < existing_workers->size(); i++) { - // If the RPC fails (indicating that the requested ID doesn't exist on - // remote), or the returned view ID is not equal to the local one - // (indicating that the remote worker has a stale view of cluster), treat - // the worker as replaced. 
- if (!statuses[i].ok() || - responses[i].context_view_id() != context_view_id) { - replaced_workers->emplace_back(existing_workers->at(i)); - } - } - return tensorflow::Status::OK(); -} - -tensorflow::Status CreateRemoteContexts( - TFE_Context* ctx, const std::vector& remote_workers, - tensorflow::uint64 context_id, tensorflow::uint64 context_view_id, - int keep_alive_secs, const tensorflow::ServerDef& server_def, - tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, - const bool lazy_copy_remote_function_inputs, - const tensorflow::eager::CreateContextRequest& base_request) { - int num_remote_workers = remote_workers.size(); - tensorflow::BlockingCounter counter(num_remote_workers); - std::vector statuses(num_remote_workers); - for (int i = 0; i < num_remote_workers; i++) { - const string& remote_worker = remote_workers[i]; - tensorflow::DeviceNameUtils::ParsedName parsed_name; - if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, - &parsed_name)) { - statuses[i] = tensorflow::errors::InvalidArgument( - "Unable to parse ", remote_worker, " as a device name"); - counter.DecrementCount(); - continue; - } - - tensorflow::core::RefCountPtr eager_client; - statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client); - if (eager_client == nullptr) { - statuses[i] = tensorflow::errors::Internal( - "Cannot find a client for the given target:", remote_worker); - } - if (!statuses[i].ok()) { - counter.DecrementCount(); - continue; - } - - tensorflow::eager::CreateContextRequest request; - tensorflow::eager::CreateContextResponse* response = - new tensorflow::eager::CreateContextResponse(); - request.set_context_id(context_id); - request.set_context_view_id(context_view_id); - *request.mutable_server_def() = server_def; - request.mutable_server_def()->set_job_name(parsed_name.job); - request.mutable_server_def()->set_task_index(parsed_name.task); - request.mutable_server_def()->mutable_default_session_config()->MergeFrom( - server_def.default_session_config()); - - std::vector filtered_device_mask; - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->FilterDevicesForRemoteWorkers( - remote_worker, base_request.cluster_device_attributes(), - &filtered_device_mask); - DCHECK_EQ(filtered_device_mask.size(), - base_request.cluster_device_attributes_size()); - for (int i = 0; i < filtered_device_mask.size(); i++) { - if (filtered_device_mask[i]) { - const auto& da = base_request.cluster_device_attributes(i); - *request.add_cluster_device_attributes() = da; - } - } - request.set_async(async); - request.set_keep_alive_secs(keep_alive_secs); - request.set_lazy_copy_remote_function_inputs( - lazy_copy_remote_function_inputs); - - eager_client->CreateContextAsync( - &request, response, - [i, &statuses, &counter, response](const tensorflow::Status& s) { - statuses[i] = s; - delete response; - counter.DecrementCount(); - }); - } - counter.Wait(); - tensorflow::StatusGroup sg; - for (int i = 0; i < num_remote_workers; i++) { - if (TF_PREDICT_FALSE(!statuses[i].ok())) { - sg.Update(statuses[i]); - } - } - return sg.as_summary_status(); -} - -tensorflow::Status UpdateRemoteContexts( - TFE_Context* ctx, const std::vector& remote_workers, - const std::vector& added_workers, - const std::vector& removed_workers, tensorflow::uint64 context_id, - tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def, - tensorflow::eager::EagerClientCache* remote_eager_workers, - const 
tensorflow::eager::CreateContextRequest& base_request) { - int num_remote_workers = remote_workers.size(); - tensorflow::BlockingCounter counter(num_remote_workers); - std::vector statuses(num_remote_workers); - - int cluster_device_count = base_request.cluster_device_attributes_size(); - std::unordered_set added_or_removed(added_workers.begin(), - added_workers.end()); - std::copy(removed_workers.begin(), removed_workers.end(), - std::inserter(added_or_removed, added_or_removed.end())); - // Whether each device is in the updated (added or removed) workers - std::vector device_added_or_removed(cluster_device_count); - for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) { - const auto& da = base_request.cluster_device_attributes().at(i); - tensorflow::DeviceNameUtils::ParsedName pn; - tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn); - string task_name; - tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name); - if (added_or_removed.find(task_name) != added_or_removed.end()) { - device_added_or_removed[i] = true; - } - } - - for (int i = 0; i < num_remote_workers; i++) { - const string& remote_worker = remote_workers[i]; - tensorflow::DeviceNameUtils::ParsedName parsed_name; - if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, - &parsed_name)) { - statuses[i] = tensorflow::errors::InvalidArgument( - "Unable to parse ", remote_worker, " as a device name"); - counter.DecrementCount(); - continue; - } - - tensorflow::core::RefCountPtr eager_client; - statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client); - if (eager_client == nullptr) { - statuses[i] = tensorflow::errors::Internal( - "Cannot find a client for the given target:", remote_worker); - } - if (!statuses[i].ok()) { - counter.DecrementCount(); - continue; - } - - std::vector filtered_device_mask; - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->FilterDevicesForRemoteWorkers( - remote_worker, base_request.cluster_device_attributes(), - &filtered_device_mask); - DCHECK_EQ(filtered_device_mask.size(), cluster_device_count); - - // If any of the devices that match the device filters are in the set of - // added or removed workers, we must send a complete UpdateContextRequest. - // Otherwise, only send a simple request to increment context view ID. 
- std::vector added_or_removed_filtered_devices(cluster_device_count); - std::transform(device_added_or_removed.begin(), - device_added_or_removed.end(), filtered_device_mask.begin(), - added_or_removed_filtered_devices.begin(), - std::logical_and()); - const bool full_update_request = - std::accumulate(added_or_removed_filtered_devices.begin(), - added_or_removed_filtered_devices.end(), false, - std::logical_or()); - - tensorflow::eager::UpdateContextRequest request; - auto* response = new tensorflow::eager::UpdateContextResponse(); - request.set_context_id(context_id); - request.set_context_view_id(context_view_id); - if (full_update_request) { - *request.mutable_server_def() = server_def; - request.mutable_server_def()->set_job_name(parsed_name.job); - request.mutable_server_def()->set_task_index(parsed_name.task); - request.mutable_server_def()->mutable_default_session_config()->MergeFrom( - server_def.default_session_config()); - for (int i = 0; i < cluster_device_count; i++) { - if (filtered_device_mask[i]) { - const auto& da = base_request.cluster_device_attributes(i); - *request.add_cluster_device_attributes() = da; - } - } - } - - eager_client->UpdateContextAsync( - &request, response, - [i, &statuses, &counter, response](const tensorflow::Status& s) { - statuses[i] = s; - delete response; - counter.DecrementCount(); - }); - } - counter.Wait(); - for (int i = 0; i < num_remote_workers; i++) { - TF_RETURN_IF_ERROR(statuses[i]); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status UpdateTFE_ContextWithServerDef( - int keep_alive_secs, const tensorflow::ServerDef& server_def, - TFE_Context* ctx, bool reset_context) { - // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the - // server object (which currently CHECK-fails) and we miss the error, instead, - // we log the error, and then return to allow the user to see the error - // message. -#define LOG_AND_RETURN_IF_ERROR(...) \ - do { \ - const ::tensorflow::Status _status = (__VA_ARGS__); \ - if (TF_PREDICT_FALSE(!_status.ok())) { \ - LOG(ERROR) << _status.error_message(); \ - return _status; \ - } \ - } while (0); - - string worker_name = - tensorflow::strings::StrCat("/job:", server_def.job_name(), - "/replica:0/task:", server_def.task_index()); - - // List of current remote workers before updating server_def. Unused if - // resetting the server_def. - std::vector curr_remote_workers; - // List of updated remote workers. - std::vector remote_workers; - - // New server created for new server_def. Unused if updating server_def. - std::unique_ptr new_server; - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - tensorflow::GrpcServer* grpc_server; - if (reset_context) { - const tensorflow::DeviceMgr* device_mgr = - AreLocalDevicesCompatible(context, server_def) - ? context->local_device_mgr() - : nullptr; - LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions( - server_def, {device_mgr}, &new_server)); - grpc_server = dynamic_cast(new_server.get()); - LOG_AND_RETURN_IF_ERROR( - ListRemoteWorkers(new_server.get(), worker_name, &remote_workers)); - } else { - LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name, - &curr_remote_workers)); - // No need to check the cast here, since `ListRemoteWorkers` already checks - // if the server is a GRPC server or not. 
- grpc_server = dynamic_cast(context->GetServer()); - LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def)); - LOG_AND_RETURN_IF_ERROR( - ListRemoteWorkers(grpc_server, worker_name, &remote_workers)); - } - - tensorflow::uint64 context_id = context->GetContextId(); - tensorflow::uint64 context_view_id = context->GetContextViewId(); - if (reset_context) { - context_id = tensorflow::EagerContext::NewContextId(); - context_view_id = 0; - // Make master eager context accessible by local eager service, which might - // receive send tensor requests from remote workers. - LOG_AND_RETURN_IF_ERROR( - grpc_server->AddMasterEagerContextToEagerService(context_id, context)); - } - - std::unique_ptr remote_eager_workers; - LOG_AND_RETURN_IF_ERROR( - grpc_server->master_env()->worker_cache->GetEagerClientCache( - &remote_eager_workers)); - - // For cluster update, use a status group to aggregate statuses from - // * adding and removing remote devices - // * creating remote contexts on newly added workers - // * updating remote contexts on existing workers - // * updating the master context - // Note that we should not return immediately on errors in the middle of these - // updates to prevent cluster from having inconsistent context views. - // - // Unused if `reset_context` is True. - tensorflow::StatusGroup sg; - - // When updating an existing context, populate the following lists with: - // * added_workers: set(remote_workers) - set(curr_remote_workers) - // * removed_workers: set(curr_remote_workers) - set(remote_workers) - // * existing_workers: set(curr_remote_workers) intersect set(remote_workers) - // * replaced_workers: workers with the same task names and potentially the - // same `hostname:port`s, but replaced by different processes - std::vector added_workers; - std::vector removed_workers; - std::vector existing_workers; - std::vector replaced_workers; - - // New remote device manager created for new server_def. Unused if updating - // server_def. - std::unique_ptr new_remote_device_mgr; - tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr; - if (reset_context) { - LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices( - remote_workers, grpc_server->master_env()->worker_cache, - &new_remote_device_mgr)); - remote_device_mgr = new_remote_device_mgr.get(); - } else { - context->ClearCachesAndDefaultExecutor(); - // TODO(b/143914772): Potential memory leak if rendezvous has pending - // tensors for removed / replaced workers. 
- - remote_device_mgr = context->GetOwnedRemoteDeviceMgr(); - if (remote_device_mgr == nullptr) { - LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument( - "Updating context with an invalid set of remote devices.")); - } - std::sort(curr_remote_workers.begin(), curr_remote_workers.end()); - std::sort(remote_workers.begin(), remote_workers.end()); - DifferentiateWorkerLists(&curr_remote_workers, &remote_workers, - &added_workers, &removed_workers, - &existing_workers); - sg.Update(GetReplacedFromExistingWorkers( - &existing_workers, context_id, context->GetContextViewId(), server_def, - remote_eager_workers.get(), &replaced_workers)); - if (VLOG_IS_ON(1)) { - VLOG(1) << "Updating cluster with following changes"; - for (const string& w : added_workers) VLOG(1) << " Added worker " << w; - for (const string& w : removed_workers) - VLOG(1) << " Removed worker " << w; - for (const string& w : replaced_workers) - VLOG(1) << " Replaced worker " << w; - } - if (!replaced_workers.empty()) { - // Treat replaced workers as removed then added back, so that we recreate - // remote devices and contexts, and re-register functions on those workers - removed_workers.insert(removed_workers.end(), replaced_workers.begin(), - replaced_workers.end()); - added_workers.insert(added_workers.end(), replaced_workers.begin(), - replaced_workers.end()); - for (const string& w : replaced_workers) { - existing_workers.erase( - std::remove(existing_workers.begin(), existing_workers.end(), w), - existing_workers.end()); - } - } - sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr)); - sg.Update(AddRemoteDevicesToMgr(added_workers, - grpc_server->master_env()->worker_cache, - remote_device_mgr)); - } - - std::vector cluster_device_attributes; - remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes); - - std::vector local_device_attributes; - grpc_server->worker_env()->device_mgr->ListDeviceAttributes( - &local_device_attributes); - - // This request make sure that we can create Rendezvous properly between - // Local and Remote context. - tensorflow::eager::CreateContextRequest base_request; - for (const auto& da : cluster_device_attributes) { - *base_request.add_cluster_device_attributes() = da; - } - for (const auto& da : local_device_attributes) { - *base_request.add_cluster_device_attributes() = da; - } - - // Initialize remote eager workers. - if (reset_context) { - const tensorflow::Status s = CreateRemoteContexts( - ctx, remote_workers, context_id, context_view_id, keep_alive_secs, - server_def, remote_eager_workers.get(), context->Executor().Async(), - context->LazyCopyFunctionRemoteInputs(), base_request); - // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause - // the CreateRemoteContexts to fail. We currently only log instead of - // directly returning the error, since returning here will cause the server - // object to be destroyed (which currently CHECK-fails). The client will - // see additional errors if ops are subsequently sent to the failed workers. - if (TF_PREDICT_FALSE(!s.ok())) { - LOG(ERROR) << "Error when creating contexts on remote targets: " - << s.error_message() - << "\nExecuting remote ops or functions on these remote " - "targets will fail."; - } - } else { - if (sg.ok()) { - // Create remote contexts on the newly added workers only if the master - // has collected all device information from them (i.e., the - // GetAllRemoteDevices call returns succussfully). 
Note that in rare cases - // GetAllRemoteDevices can still fail even with RPCs configured to wait - // until the remote workers to become alive. If the master creates remote - // contexts on the workers whose devices are still not collected, those - // workers will be treated as existing workers subsequently, so the master - // will never get devices from them even with retrying UpdateServerDef. - sg.Update(CreateRemoteContexts( - ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs, - server_def, remote_eager_workers.get(), context->Executor().Async(), - context->LazyCopyFunctionRemoteInputs(), base_request)); - } - if (!existing_workers.empty()) { - if (VLOG_IS_ON(1)) { - for (const string& w : existing_workers) { - VLOG(1) << "Updating cluster with existing worker " << w; - } - } - // The master's context_view_id will be incremented by one in the - // UpdateRemoteMaster call later. We want existing workers to also have - // the updated context_view_id, so we must set their context_view_id to - // the master's current context_view_id + 1. - sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers, - removed_workers, context_id, - context_view_id + 1, server_def, - remote_eager_workers.get(), base_request)); - } - } - - auto session_name = tensorflow::strings::StrCat("eager_", context_id); - if (reset_context) { - tensorflow::RemoteRendezvous* r = - grpc_server->worker_env()->rendezvous_mgr->Find(context_id); - auto* device_mgr = grpc_server->worker_env()->device_mgr; - std::shared_ptr worker_session; - LOG_AND_RETURN_IF_ERROR( - grpc_server->worker_env()->session_mgr->CreateSession( - session_name, server_def, base_request.cluster_device_attributes(), - true)); - LOG_AND_RETURN_IF_ERROR( - grpc_server->worker_env()->session_mgr->WorkerSessionForSession( - session_name, &worker_session)); - - // Initialize remote tensor communication based on worker session. - LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get())); - - tensorflow::DistributedFunctionLibraryRuntime* cluster_flr = - tensorflow::eager::CreateClusterFLR(context_id, context, - worker_session.get()); - auto remote_mgr = absl::make_unique( - /*is_master=*/true, context); - - LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster( - std::move(new_server), grpc_server->worker_env(), worker_session, - std::move(remote_eager_workers), std::move(new_remote_device_mgr), - remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr, - std::move(remote_mgr))); - - // NOTE: We start the server after all other initialization, because the - // GrpcServer cannot be destroyed after it is started. 
-    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
-  } else {
-    sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
-        session_name, server_def, base_request.cluster_device_attributes(),
-        /*isolate_session_state=*/true));
-    sg.Update(context->UpdateRemoteMaster(context_id,
-                                          std::move(remote_eager_workers),
-                                          added_workers, removed_workers));
-    LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
-  }
-#undef LOG_AND_RETURN_IF_ERROR
-
-  return tensorflow::Status::OK();
-}
-#endif  // !IS_MOBILE_PLATFORM
-
 }  // namespace

 extern "C" {

@@ -735,7 +105,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
 #else
   status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
   return nullptr;
-#endif
+#endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
 }
   std::vector<std::unique_ptr<tensorflow::Device>> devices;
   status->status = tensorflow::DeviceFactory::AddDevices(
@@ -747,13 +117,18 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   tensorflow::Rendezvous* r =
       new tensorflow::IntraProcessRendezvous(device_mgr.get());
-
-  return tensorflow::wrap(new tensorflow::EagerContext(
+  tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
       opts->session_options.options,
       static_cast<tensorflow::ContextDevicePlacementPolicy>(
           opts->device_placement_policy),
       opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
-      /*device_mgr_owned*/ true, r));
+      /*device_mgr_owned*/ true, r);
+#if !defined(IS_MOBILE_PLATFORM)
+  eager_context->SetDistributedManager(
+      std::make_unique<tensorflow::EagerContextDistributedManager>(
+          eager_context));
+#endif  // !IS_MOBILE_PLATFORM
+  return tensorflow::wrap(eager_context);
 }

 void TFE_DeleteContext(TFE_Context* ctx) {
@@ -791,26 +166,9 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
         "Invalid tensorflow.ServerDef protocol buffer");
     return;
   }
-  if (server_def.has_cluster_device_filters()) {
-    const auto& cdf = server_def.cluster_device_filters();
-    for (const auto& jdf : cdf.jobs()) {
-      const string remote_prefix = "/job:" + jdf.name() + "/task:";
-      for (const auto& tdf : jdf.tasks()) {
-        const int32_t task_index = tdf.first;
-        std::vector<string> device_filters(tdf.second.device_filters_size());
-        for (int i = 0; i < tdf.second.device_filters_size(); i++) {
-          device_filters[i] = tdf.second.device_filters(i);
-        }
-        const string remote_worker = remote_prefix + std::to_string(task_index);
-        tensorflow::EagerContext* context =
-            tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-        status->status =
-            context->SetRemoteDeviceFilters(remote_worker, device_filters);
-      }
-    }
-  }
-  status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
-                                                  ctx, /*reset_context=*/true);
+  status->status =
+      tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
+          server_def, /*reset_context=*/true, keep_alive_secs);
 #endif  // !IS_MOBILE_PLATFORM
 }

@@ -835,14 +193,9 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
     status->status = tensorflow::errors::InvalidArgument(
         "Trying to update a context with invalid context id.");
   }
-  if (server_def.has_cluster_device_filters()) {
-    LOG(WARNING) << "Device filters can only be specified when initializing "
-                    "the cluster.
Any changes in device filters are ignored " - "when updating the server def."; - } - // TODO(haoyuzhang): Check server_def compatibility before the update - status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, - ctx, /*reset_context=*/false); + status->status = + tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef( + server_def, /*reset_context=*/false, keep_alive_secs); #endif // !IS_MOBILE_PLATFORM } @@ -854,44 +207,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, "TFE_ContextSetServerDef not supported on mobile"); return false; #else // !defined(IS_MOBILE_PLATFORM) - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - tensorflow::GrpcServer* grpc_server = - dynamic_cast(context->GetServer()); - if (grpc_server == nullptr) { - status->status = - tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer."); - return false; - } - tensorflow::WorkerInterface* wi = - grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name); - if (wi == nullptr) { - status->status = tensorflow::errors::InvalidArgument( - "Unable to find worker interface corresponding to task ", worker_name); - return false; - } - - tensorflow::GetStatusRequest request; - tensorflow::GetStatusResponse response; - tensorflow::Status remote_status; - tensorflow::Notification done; - wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true, - [&remote_status, &done](const tensorflow::Status& s) { - remote_status = s; - done.Notify(); - }); - done.WaitForNotification(); - - // We set OK status so the call does not raise any exceptions. Instead, caller - // users the return value to tell if the remote worker is alive. - status->status = tensorflow::Status::OK(); - - if (remote_status.ok()) { - return true; - } - LOG(INFO) << "Remote worker " << worker_name - << " is not alive: " << remote_status.error_message(); - return false; + bool is_alive; + status->status = + tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive( + worker_name, &is_alive); + return is_alive; #endif // !IS_MOBILE_PLATFORM } diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index b0fb9ca51d3..1d1f348fbff 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/immediate_execution_distributed_manager.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/tensor_interface.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/protobuf/config.pb.h" @@ -174,6 +176,18 @@ class ImmediateExecutionContext : public AbstractContext { virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface( ImmediateExecutionTensorHandle* handle) = 0; + //===--------------------------------------------------------------------===// + // Distributed runtime related functions. 
+  //===--------------------------------------------------------------------===//
+#if !defined(IS_MOBILE_PLATFORM)
+  // Set a distributed manager that helps set up, update, and check liveness
+  // of member tasks in the cluster.
+  virtual void SetDistributedManager(
+      std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;
+
+  virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
+#endif  // !IS_MOBILE_PLATFORM
+
 protected:
  explicit ImmediateExecutionContext(AbstractContextKind kind)
      : AbstractContext(kind) {}
diff --git a/tensorflow/c/eager/immediate_execution_distributed_manager.h b/tensorflow/c/eager/immediate_execution_distributed_manager.h
new file mode 100644
index 00000000000..1b7ce2ed156
--- /dev/null
+++ b/tensorflow/c/eager/immediate_execution_distributed_manager.h
@@ -0,0 +1,45 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
+#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
+
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
+
+namespace tensorflow {
+class ImmediateExecutionContext;
+
+class ImmediateExecutionDistributedManager {
+ public:
+  virtual ~ImmediateExecutionDistributedManager() {}
+
+  // Set up the distributed execution environment on local and remote tasks.
+  // When `reset_context` is true, initialize new cluster context state based
+  // on the cluster configuration provided in `server_def`; otherwise, update
+  // the existing context state with the provided `server_def`.
+  // Contexts created on remote tasks will be considered stale and garbage
+  // collected after `keep_alive_secs` of inactivity.
+  virtual Status SetOrUpdateServerDef(const ServerDef& server_def,
+                                      bool reset_context,
+                                      int keep_alive_secs) = 0;
+
+  // Check whether the remote task is alive.
+  virtual Status CheckRemoteAlive(const std::string& remote_task_name,
+                                  bool* is_alive) = 0;
+};
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 91f60d6ebe2..728aacb36e4 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -80,6 +80,7 @@ tf_cuda_library(
         "@com_google_absl//absl/container:flat_hash_map",
         "//tensorflow/c:tf_tensor_internal",
         "//tensorflow/c/eager:immediate_execution_context",
+        "//tensorflow/c/eager:immediate_execution_distributed_manager",
         "//tensorflow/c/eager:immediate_execution_tensor_handle",
         "//tensorflow/c/eager:immediate_execution_operation",
         "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
@@ -109,6 +110,47 @@ tf_cuda_library(
     }),
 )

+tf_cuda_library(
+    name = "context_distributed_manager",
+    srcs = [
+        "context_distributed_manager.cc",
+    ],
+    hdrs = [
+        "context_distributed_manager.h",
+    ],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":context",
+        ":eager_executor",
+        "//tensorflow/c/eager:immediate_execution_context",
+        "//tensorflow/c/eager:immediate_execution_distributed_manager",
+        "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
+        "//tensorflow/core/distributed_runtime:worker_env",
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:core_cpu_lib",
+            "//tensorflow/core:framework",
+            "//tensorflow/core:framework_internal",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:lib_internal",
+            "//tensorflow/core:protos_all_cc",
+            "//tensorflow/core:session_options",
+            "//tensorflow/core/distributed_runtime:worker_cache",
+            "//tensorflow/core/distributed_runtime:worker_interface",
+            "//tensorflow/core/distributed_runtime:remote_device",
+            "//tensorflow/core/distributed_runtime:server_lib",
+            "//tensorflow/core/distributed_runtime:worker_session",
+            "//tensorflow/core/distributed_runtime/eager:eager_client",
+            "//tensorflow/core/distributed_runtime/eager:cluster_function_library_runtime",
+            "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+            "//tensorflow/core/distributed_runtime/eager:remote_mgr",
+        ],
+    }),
+)
+
 tf_cc_test(
     name = "context_test",
     srcs = ["context_test.cc"],
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 235aa58757d..1e9516d5a69 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -24,15 +24,8 @@ limitations under the License.
 #include
 #include

-// clang-format off
-// Required for IS_MOBILE_PLATFORM
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/platform.h"
-// clang-format on
-
-#include "absl/types/optional.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/types/optional.h"
 #include "tensorflow/c/eager/immediate_execution_context.h"
 #include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
@@ -40,12 +33,34 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" + +// "tensorflow/core/platform/platform.h" must be included first before using +// IS_MOBILE_PLATFORM. #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" @@ -53,23 +68,6 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #endif // !IS_MOBILE_PLATFORM -#include "tensorflow/core/framework/collective.h" -#include "tensorflow/core/framework/log_memory.h" -#include "tensorflow/core/framework/rendezvous.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/gtl/map_util.h" - -#include "tensorflow/core/platform/casts.h" -#include "tensorflow/core/platform/fingerprint.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/public/session_options.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -442,6 +440,15 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { return ptr->getKind() == kEager; } + // Function to support distributed C API. + void SetDistributedManager( + std::unique_ptr distributed) + override { + distributed_manager_ = std::move(distributed); + } + ImmediateExecutionDistributedManager* GetDistributedManager() override { + return distributed_manager_.get(); + } #endif // IS_MOBILE_PLATFORM // Closes remote eager contexts, waits for all RPCs to finish, and @@ -671,6 +678,10 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { std::unordered_map> cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_); + // A distributed manager that helps setup, update, and check liveness of + // member tasks in the cluster. 
+  std::unique_ptr<ImmediateExecutionDistributedManager> distributed_manager_;
+
 #endif  // IS_MOBILE_PLATFORM

 // For a multi device function, the target device of each input is unknown
diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
new file mode 100644
index 00000000000..31243149613
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
@@ -0,0 +1,716 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
+
+#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/blocking_counter.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/refcount.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/protobuf/device_filters.pb.h"
+#include "tensorflow/core/protobuf/error_codes.pb.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+#if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
+#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
+#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
+#include "tensorflow/core/distributed_runtime/remote_device.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#endif  // !IS_MOBILE_PLATFORM
+
+namespace tensorflow {
+#if !defined(IS_MOBILE_PLATFORM)
+namespace {
+bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
+                               const tensorflow::ServerDef& server_def) {
+  if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
+    return false;
+  }
+  return server_def.default_session_config().SerializeAsString() ==
+         context->session_options().config.SerializeAsString();
+}
+
+tensorflow::Status AddRemoteDevicesToMgr(
+    const std::vector<string>& added_remote_workers,
+    tensorflow::WorkerCacheInterface* worker_cache,
+    tensorflow::DynamicDeviceMgr* remote_device_mgr) {
+  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
+  tensorflow::mutex remote_devices_mu;
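// NewRemoteDevices resolves each worker's devices through an asynchronous
// RPC, so the callbacks below can run concurrently on different threads:
// `remote_devices_mu` serializes appends to the shared `remote_devices`
// vector, while `statuses` needs no lock because each callback writes only
// its own element.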
int num_added_workers = added_remote_workers.size(); + tensorflow::BlockingCounter counter(num_added_workers); + std::vector statuses(num_added_workers); + for (int i = 0; i < num_added_workers; i++) { + tensorflow::NewRemoteDevices( + tensorflow::Env::Default(), worker_cache, added_remote_workers[i], + [i, &statuses, &counter, &remote_devices, &remote_devices_mu]( + const tensorflow::Status& s, + std::vector* devices) { + statuses[i] = s; + if (s.ok()) { + tensorflow::mutex_lock l(remote_devices_mu); + for (tensorflow::Device* d : *devices) { + remote_devices.emplace_back(d); + } + } + counter.DecrementCount(); + }); + } + counter.Wait(); + for (int i = 0; i < num_added_workers; i++) { + TF_RETURN_IF_ERROR(statuses[i]); + } + + TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices))); + return tensorflow::Status::OK(); +} + +tensorflow::Status GetAllRemoteDevices( + const std::vector& remote_workers, + tensorflow::WorkerCacheInterface* worker_cache, + std::unique_ptr* device_mgr) { + auto remote_device_mgr = std::make_unique(); + TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache, + remote_device_mgr.get())); + *device_mgr = std::move(remote_device_mgr); + return tensorflow::Status::OK(); +} + +tensorflow::Status RemoveRemoteDevicesFromMgr( + const std::vector& removed_remote_workers, + tensorflow::DynamicDeviceMgr* remote_device_mgr) { + const std::vector remote_devices = + (remote_device_mgr->ListDevices()); + std::vector devices_to_remove; + for (tensorflow::Device* d : remote_devices) { + for (const string& remote_worker : removed_remote_workers) { + if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker, + d->name())) { + devices_to_remove.emplace_back(d); + break; + } + } + } + TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server, + const string& local_worker, + std::vector* remote_workers) { + tensorflow::GrpcServer* grpc_server = + dynamic_cast(server); + if (grpc_server == nullptr) { + return tensorflow::errors::Internal( + "Currently, TFE_NewContext only supports tensorflow::GrpcServer."); + } + grpc_server->master_env()->worker_cache->ListWorkers(remote_workers); + remote_workers->erase( + std::remove(remote_workers->begin(), remote_workers->end(), local_worker), + remote_workers->end()); + return tensorflow::Status::OK(); +} + +void DifferentiateWorkerLists(const std::vector* current_list, + const std::vector* new_list, + std::vector* added, + std::vector* removed, + std::vector* existing) { + // Get STL set_difference and set_intersection with one list traversal. + // Similar to the set_difference library function, the input lists + // (`current_list` and `new_list`) must be sorted before calling the function. 
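// For example, with sorted inputs
//   current_list = {"/job:worker/task:0", "/job:worker/task:1"}
//   new_list     = {"/job:worker/task:1", "/job:worker/task:2"}
// a single merge pass yields removed = {"/job:worker/task:0"},
// existing = {"/job:worker/task:1"}, and added = {"/job:worker/task:2"}.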
+ added->resize(new_list->size()); + removed->resize(current_list->size()); + existing->resize(current_list->size()); + std::vector::const_iterator curr_it = current_list->begin(); + std::vector::const_iterator new_it = new_list->begin(); + std::vector::iterator added_it = added->begin(); + std::vector::iterator removed_it = removed->begin(); + std::vector::iterator existing_it = existing->begin(); + while (curr_it != current_list->end() && new_it != new_list->end()) { + if (*curr_it < *new_it) { + *removed_it++ = *curr_it++; + } else if (*curr_it > *new_it) { + *added_it++ = *new_it++; + } else { + *existing_it++ = *curr_it++; + new_it++; + } + } + removed_it = std::copy(curr_it, current_list->end(), removed_it); + added_it = std::copy(new_it, new_list->end(), added_it); + added->resize(added_it - added->begin()); + removed->resize(removed_it - removed->begin()); + existing->resize(existing_it - existing->begin()); +} + +tensorflow::Status GetReplacedFromExistingWorkers( + const std::vector* existing_workers, tensorflow::uint64 context_id, + tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def, + tensorflow::eager::EagerClientCache* client_cache, + std::vector* replaced_workers) { + tensorflow::BlockingCounter counter(existing_workers->size()); + std::vector statuses(existing_workers->size()); + tensorflow::eager::KeepAliveRequest request; + request.set_context_id(context_id); + std::vector responses( + existing_workers->size()); + for (int i = 0; i < existing_workers->size(); i++) { + tensorflow::core::RefCountPtr eager_client; + statuses[i] = + client_cache->GetClient(existing_workers->at(i), &eager_client); + if (!statuses[i].ok()) { + counter.DecrementCount(); + continue; + } + eager_client->KeepAliveAsync( + &request, &responses[i], + [i, &statuses, &counter](const tensorflow::Status& s) { + statuses[i] = s; + counter.DecrementCount(); + }); + } + counter.Wait(); + for (int i = 0; i < existing_workers->size(); i++) { + // If the RPC fails (indicating that the requested ID doesn't exist on + // remote), or the returned view ID is not equal to the local one + // (indicating that the remote worker has a stale view of cluster), treat + // the worker as replaced. 
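// KeepAliveAsync thus doubles as a liveness and consistency probe: the
// request carries the master's context_id, and a healthy worker echoes back
// the context_view_id it currently holds for that context.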
+ if (!statuses[i].ok() || + responses[i].context_view_id() != context_view_id) { + replaced_workers->emplace_back(existing_workers->at(i)); + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status CreateRemoteContexts( + EagerContext* context, const std::vector& remote_workers, + tensorflow::uint64 context_id, tensorflow::uint64 context_view_id, + int keep_alive_secs, const tensorflow::ServerDef& server_def, + tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, + const bool lazy_copy_remote_function_inputs, + const tensorflow::eager::CreateContextRequest& base_request) { + int num_remote_workers = remote_workers.size(); + tensorflow::BlockingCounter counter(num_remote_workers); + std::vector statuses(num_remote_workers); + for (int i = 0; i < num_remote_workers; i++) { + const string& remote_worker = remote_workers[i]; + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, + &parsed_name)) { + statuses[i] = tensorflow::errors::InvalidArgument( + "Unable to parse ", remote_worker, " as a device name"); + counter.DecrementCount(); + continue; + } + + tensorflow::core::RefCountPtr eager_client; + statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client); + if (eager_client == nullptr) { + statuses[i] = tensorflow::errors::Internal( + "Cannot find a client for the given target:", remote_worker); + } + if (!statuses[i].ok()) { + counter.DecrementCount(); + continue; + } + + tensorflow::eager::CreateContextRequest request; + tensorflow::eager::CreateContextResponse* response = + new tensorflow::eager::CreateContextResponse(); + request.set_context_id(context_id); + request.set_context_view_id(context_view_id); + *request.mutable_server_def() = server_def; + request.mutable_server_def()->set_job_name(parsed_name.job); + request.mutable_server_def()->set_task_index(parsed_name.task); + request.mutable_server_def()->mutable_default_session_config()->MergeFrom( + server_def.default_session_config()); + + std::vector filtered_device_mask; + context->FilterDevicesForRemoteWorkers( + remote_worker, base_request.cluster_device_attributes(), + &filtered_device_mask); + DCHECK_EQ(filtered_device_mask.size(), + base_request.cluster_device_attributes_size()); + for (int i = 0; i < filtered_device_mask.size(); i++) { + if (filtered_device_mask[i]) { + const auto& da = base_request.cluster_device_attributes(i); + *request.add_cluster_device_attributes() = da; + } + } + request.set_async(async); + request.set_keep_alive_secs(keep_alive_secs); + request.set_lazy_copy_remote_function_inputs( + lazy_copy_remote_function_inputs); + + eager_client->CreateContextAsync( + &request, response, + [i, &statuses, &counter, response](const tensorflow::Status& s) { + statuses[i] = s; + delete response; + counter.DecrementCount(); + }); + } + counter.Wait(); + tensorflow::StatusGroup sg; + for (int i = 0; i < num_remote_workers; i++) { + if (TF_PREDICT_FALSE(!statuses[i].ok())) { + sg.Update(statuses[i]); + } + } + return sg.as_summary_status(); +} + +tensorflow::Status UpdateRemoteContexts( + EagerContext* context, const std::vector& remote_workers, + const std::vector& added_workers, + const std::vector& removed_workers, tensorflow::uint64 context_id, + tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def, + tensorflow::eager::EagerClientCache* remote_eager_workers, + const tensorflow::eager::CreateContextRequest& base_request) { + int num_remote_workers = remote_workers.size(); + 
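// Note on the two identifiers used below: `context_id` names the cluster
// context and changes only when the context is reset, whereas
// `context_view_id` counts cluster-membership updates within that context;
// each successful update advances the view id held by the remote workers.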
+
+tensorflow::Status UpdateRemoteContexts(
+    EagerContext* context, const std::vector<string>& remote_workers,
+    const std::vector<string>& added_workers,
+    const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
+    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
+    tensorflow::eager::EagerClientCache* remote_eager_workers,
+    const tensorflow::eager::CreateContextRequest& base_request) {
+  int num_remote_workers = remote_workers.size();
+  tensorflow::BlockingCounter counter(num_remote_workers);
+  std::vector<tensorflow::Status> statuses(num_remote_workers);
+
+  int cluster_device_count = base_request.cluster_device_attributes_size();
+  std::unordered_set<string> added_or_removed(added_workers.begin(),
+                                              added_workers.end());
+  std::copy(removed_workers.begin(), removed_workers.end(),
+            std::inserter(added_or_removed, added_or_removed.end()));
+  // Whether each device is in the updated (added or removed) workers.
+  std::vector<bool> device_added_or_removed(cluster_device_count);
+  for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
+    const auto& da = base_request.cluster_device_attributes().at(i);
+    tensorflow::DeviceNameUtils::ParsedName pn;
+    tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
+    string task_name;
+    tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
+    if (added_or_removed.find(task_name) != added_or_removed.end()) {
+      device_added_or_removed[i] = true;
+    }
+  }
+
+  for (int i = 0; i < num_remote_workers; i++) {
+    const string& remote_worker = remote_workers[i];
+    tensorflow::DeviceNameUtils::ParsedName parsed_name;
+    if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
+                                                    &parsed_name)) {
+      statuses[i] = tensorflow::errors::InvalidArgument(
+          "Unable to parse ", remote_worker, " as a device name");
+      counter.DecrementCount();
+      continue;
+    }
+
+    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
+    statuses[i] =
+        remote_eager_workers->GetClient(remote_worker, &eager_client);
+    if (eager_client == nullptr) {
+      statuses[i] = tensorflow::errors::Internal(
+          "Cannot find a client for the given target:", remote_worker);
+    }
+    if (!statuses[i].ok()) {
+      counter.DecrementCount();
+      continue;
+    }
+
+    std::vector<bool> filtered_device_mask;
+    context->FilterDevicesForRemoteWorkers(
+        remote_worker, base_request.cluster_device_attributes(),
+        &filtered_device_mask);
+    DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
+
+    // If any of the devices that match the device filters are in the set of
+    // added or removed workers, we must send a complete UpdateContextRequest.
+    // Otherwise, only send a simple request to increment the context view ID.
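+    // For example (hypothetical masks): with
+    //   device_added_or_removed = {true, false, true} and
+    //   filtered_device_mask    = {true, true, false},
+    // the element-wise AND below yields {true, false, false}; its
+    // OR-reduction is true, so this worker receives a full update.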
+    std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
+    std::transform(device_added_or_removed.begin(),
+                   device_added_or_removed.end(), filtered_device_mask.begin(),
+                   added_or_removed_filtered_devices.begin(),
+                   std::logical_and<bool>());
+    const bool full_update_request =
+        std::accumulate(added_or_removed_filtered_devices.begin(),
+                        added_or_removed_filtered_devices.end(), false,
+                        std::logical_or<bool>());
+
+    tensorflow::eager::UpdateContextRequest request;
+    auto* response = new tensorflow::eager::UpdateContextResponse();
+    request.set_context_id(context_id);
+    request.set_context_view_id(context_view_id);
+    if (full_update_request) {
+      *request.mutable_server_def() = server_def;
+      request.mutable_server_def()->set_job_name(parsed_name.job);
+      request.mutable_server_def()->set_task_index(parsed_name.task);
+      request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
+          server_def.default_session_config());
+      for (int i = 0; i < cluster_device_count; i++) {
+        if (filtered_device_mask[i]) {
+          const auto& da = base_request.cluster_device_attributes(i);
+          *request.add_cluster_device_attributes() = da;
+        }
+      }
+    }
+
+    eager_client->UpdateContextAsync(
+        &request, response,
+        [i, &statuses, &counter, response](const tensorflow::Status& s) {
+          statuses[i] = s;
+          delete response;
+          counter.DecrementCount();
+        });
+  }
+  counter.Wait();
+  for (int i = 0; i < num_remote_workers; i++) {
+    TF_RETURN_IF_ERROR(statuses[i]);
+  }
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status UpdateContextWithServerDef(
+    EagerContext* context, const tensorflow::ServerDef& server_def,
+    bool reset_context, int keep_alive_secs) {
+  // We don't use the TF_RETURN_IF_ERROR macro directly here, since that would
+  // destroy the server object (which currently CHECK-fails) and we would miss
+  // the error. Instead, we log the error and then return, so that the user
+  // can see the error message.
+#define LOG_AND_RETURN_IF_ERROR(...)                    \
+  do {                                                  \
+    const ::tensorflow::Status _status = (__VA_ARGS__); \
+    if (TF_PREDICT_FALSE(!_status.ok())) {              \
+      LOG(ERROR) << _status.error_message();            \
+      return _status;                                   \
+    }                                                   \
+  } while (0);
+
+  string worker_name =
+      tensorflow::strings::StrCat("/job:", server_def.job_name(),
+                                  "/replica:0/task:", server_def.task_index());
+
+  // List of current remote workers before updating server_def. Unused if
+  // resetting the server_def.
+  std::vector<string> curr_remote_workers;
+  // List of updated remote workers.
+  std::vector<string> remote_workers;
+
+  // New server created for a new server_def. Unused if updating server_def.
+  std::unique_ptr<tensorflow::ServerInterface> new_server;
+  tensorflow::GrpcServer* grpc_server;
+  if (reset_context) {
+    const tensorflow::DeviceMgr* device_mgr =
+        AreLocalDevicesCompatible(context, server_def)
+            ? context->local_device_mgr()
+            : nullptr;
+    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
+        server_def, {device_mgr}, &new_server));
+    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
+    LOG_AND_RETURN_IF_ERROR(
+        ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
+  } else {
+    LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
+                                              &curr_remote_workers));
+    // No need to check the cast here, since `ListRemoteWorkers` already checks
+    // whether the server is a gRPC server.
+    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
+    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
+    LOG_AND_RETURN_IF_ERROR(
+        ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
+  }
+
+  tensorflow::uint64 context_id = context->GetContextId();
+  tensorflow::uint64 context_view_id = context->GetContextViewId();
+  if (reset_context) {
+    context_id = tensorflow::EagerContext::NewContextId();
+    context_view_id = 0;
+    // Make the master eager context accessible by the local eager service,
+    // which might receive send-tensor requests from remote workers.
+    LOG_AND_RETURN_IF_ERROR(
+        grpc_server->AddMasterEagerContextToEagerService(context_id, context));
+  }
+
+  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
+  LOG_AND_RETURN_IF_ERROR(
+      grpc_server->master_env()->worker_cache->GetEagerClientCache(
+          &remote_eager_workers));
+
+  // For a cluster update, use a status group to aggregate statuses from
+  //   * adding and removing remote devices,
+  //   * creating remote contexts on newly added workers,
+  //   * updating remote contexts on existing workers, and
+  //   * updating the master context.
+  // Note that we should not return immediately on errors in the middle of
+  // these updates, to prevent the cluster from having inconsistent context
+  // views.
+  //
+  // Unused if `reset_context` is true.
+  tensorflow::StatusGroup sg;
+
+  // When updating an existing context, populate the following lists with:
+  //   * added_workers: set(remote_workers) - set(curr_remote_workers)
+  //   * removed_workers: set(curr_remote_workers) - set(remote_workers)
+  //   * existing_workers: set(curr_remote_workers) intersect
+  //     set(remote_workers)
+  //   * replaced_workers: workers with the same task names and potentially the
+  //     same `hostname:port`s, but replaced by different processes
+  std::vector<string> added_workers;
+  std::vector<string> removed_workers;
+  std::vector<string> existing_workers;
+  std::vector<string> replaced_workers;
+
+  // New remote device manager created for a new server_def. Unused if
+  // updating server_def.
+  std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
+  tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
+  if (reset_context) {
+    LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
+        remote_workers, grpc_server->master_env()->worker_cache,
+        &new_remote_device_mgr));
+    remote_device_mgr = new_remote_device_mgr.get();
+  } else {
+    context->ClearCachesAndDefaultExecutor();
+    // TODO(b/143914772): Potential memory leak if rendezvous has pending
+    // tensors for removed / replaced workers.
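+
+    // Orientation note: a reset (handled above) allocates a fresh context_id
+    // with context_view_id 0, whereas this update path keeps the context_id
+    // and only advances the view, e.g. (context_id=7, view=2) becomes
+    // (context_id=7, view=3) once the remote and master contexts are updated.
+    // The ids in this example are illustrative.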
+
+    remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
+    if (remote_device_mgr == nullptr) {
+      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
+          "Updating context with an invalid set of remote devices."));
+    }
+    std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
+    std::sort(remote_workers.begin(), remote_workers.end());
+    DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
+                             &added_workers, &removed_workers,
+                             &existing_workers);
+    sg.Update(GetReplacedFromExistingWorkers(
+        &existing_workers, context_id, context->GetContextViewId(), server_def,
+        remote_eager_workers.get(), &replaced_workers));
+    if (VLOG_IS_ON(1)) {
+      VLOG(1) << "Updating cluster with the following changes:";
+      for (const string& w : added_workers) VLOG(1) << "  Added worker " << w;
+      for (const string& w : removed_workers)
+        VLOG(1) << "  Removed worker " << w;
+      for (const string& w : replaced_workers)
+        VLOG(1) << "  Replaced worker " << w;
+    }
+    if (!replaced_workers.empty()) {
+      // Treat replaced workers as removed and then added back, so that we
+      // recreate remote devices and contexts, and re-register functions on
+      // those workers.
+      removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
+                             replaced_workers.end());
+      added_workers.insert(added_workers.end(), replaced_workers.begin(),
+                           replaced_workers.end());
+      for (const string& w : replaced_workers) {
+        existing_workers.erase(
+            std::remove(existing_workers.begin(), existing_workers.end(), w),
+            existing_workers.end());
+      }
+    }
+    sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
+    sg.Update(AddRemoteDevicesToMgr(added_workers,
+                                    grpc_server->master_env()->worker_cache,
+                                    remote_device_mgr));
+  }
+
+  std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
+  remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
+
+  std::vector<tensorflow::DeviceAttributes> local_device_attributes;
+  grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
+      &local_device_attributes);
+
+  // This request makes sure that we can create Rendezvous properly between
+  // local and remote contexts.
+  tensorflow::eager::CreateContextRequest base_request;
+  for (const auto& da : cluster_device_attributes) {
+    *base_request.add_cluster_device_attributes() = da;
+  }
+  for (const auto& da : local_device_attributes) {
+    *base_request.add_cluster_device_attributes() = da;
+  }
+
+  // Initialize remote eager workers.
+  if (reset_context) {
+    const tensorflow::Status s = CreateRemoteContexts(
+        context, remote_workers, context_id, context_view_id, keep_alive_secs,
+        server_def, remote_eager_workers.get(), context->Executor().Async(),
+        context->LazyCopyFunctionRemoteInputs(), base_request);
+    // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
+    // CreateRemoteContexts to fail. We currently only log the error instead
+    // of directly returning it, since returning here would cause the server
+    // object to be destroyed (which currently CHECK-fails). The client will
+    // see additional errors if ops are subsequently sent to the failed
+    // workers.
+    if (TF_PREDICT_FALSE(!s.ok())) {
+      LOG(ERROR) << "Error when creating contexts on remote targets: "
+                 << s.error_message()
+                 << "\nExecuting remote ops or functions on these remote "
+                    "targets will fail.";
+    }
+  } else {
+    if (sg.ok()) {
+      // Create remote contexts on the newly added workers only if the master
+      // has collected all device information from them (i.e., the
+      // GetAllRemoteDevices call returns successfully).
+      // Note that in rare cases GetAllRemoteDevices can still fail even with
+      // RPCs configured to wait until the remote workers become alive. If the
+      // master creates remote contexts on workers whose devices have not yet
+      // been collected, those workers will subsequently be treated as
+      // existing workers, so the master will never get their devices, even
+      // when retrying UpdateServerDef.
+      sg.Update(CreateRemoteContexts(
+          context, added_workers, context_id, context_view_id + 1,
+          keep_alive_secs, server_def, remote_eager_workers.get(),
+          context->Executor().Async(), context->LazyCopyFunctionRemoteInputs(),
+          base_request));
+    }
+    if (!existing_workers.empty()) {
+      if (VLOG_IS_ON(1)) {
+        for (const string& w : existing_workers) {
+          VLOG(1) << "Updating cluster with existing worker " << w;
+        }
+      }
+      // The master's context_view_id will be incremented by one in the
+      // UpdateRemoteMaster call later. We want existing workers to also have
+      // the updated context_view_id, so we must set their context_view_id to
+      // the master's current context_view_id + 1.
+      sg.Update(UpdateRemoteContexts(context, existing_workers, added_workers,
+                                     removed_workers, context_id,
+                                     context_view_id + 1, server_def,
+                                     remote_eager_workers.get(), base_request));
+    }
+  }
+
+  auto session_name = tensorflow::strings::StrCat("eager_", context_id);
+  if (reset_context) {
+    tensorflow::RemoteRendezvous* r =
+        grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
+    auto* device_mgr = grpc_server->worker_env()->device_mgr;
+    std::shared_ptr<tensorflow::WorkerSession> worker_session;
+    LOG_AND_RETURN_IF_ERROR(
+        grpc_server->worker_env()->session_mgr->CreateSession(
+            session_name, server_def, base_request.cluster_device_attributes(),
+            /*isolate_session_state=*/true));
+    LOG_AND_RETURN_IF_ERROR(
+        grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
+            session_name, &worker_session));
+
+    // Initialize remote tensor communication based on the worker session.
+    LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
+
+    tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
+        tensorflow::eager::CreateClusterFLR(context_id, context,
+                                            worker_session.get());
+    auto remote_mgr = std::make_unique<tensorflow::eager::RemoteMgr>(
+        /*is_master=*/true, context);
+
+    LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
+        std::move(new_server), grpc_server->worker_env(), worker_session,
+        std::move(remote_eager_workers), std::move(new_remote_device_mgr),
+        remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
+        std::move(remote_mgr)));
+
+    // NOTE: We start the server after all other initialization, because the
+    // GrpcServer cannot be destroyed after it is started.
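+    // (A started GrpcServer currently CHECK-fails on destruction, so Start()
+    // is deliberately the last fallible step in this branch.)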
+    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
+  } else {
+    sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
+        session_name, server_def, base_request.cluster_device_attributes(),
+        /*isolate_session_state=*/true));
+    sg.Update(context->UpdateRemoteMaster(context_id,
+                                          std::move(remote_eager_workers),
+                                          added_workers, removed_workers));
+    LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
+  }
+#undef LOG_AND_RETURN_IF_ERROR
+
+  return tensorflow::Status::OK();
+}
+}  // namespace
+
+Status EagerContextDistributedManager::SetOrUpdateServerDef(
+    const ServerDef& server_def, bool reset_context, int keep_alive_secs) {
+  if (server_def.has_cluster_device_filters()) {
+    if (reset_context) {
+      const auto& cdf = server_def.cluster_device_filters();
+      for (const auto& jdf : cdf.jobs()) {
+        const string remote_prefix = "/job:" + jdf.name() + "/task:";
+        for (const auto& tdf : jdf.tasks()) {
+          const int32_t task_index = tdf.first;
+          std::vector<string> device_filters(tdf.second.device_filters_size());
+          for (int i = 0; i < tdf.second.device_filters_size(); i++) {
+            device_filters[i] = tdf.second.device_filters(i);
+          }
+          const string remote_worker =
+              strings::StrCat(remote_prefix, task_index);
+          TF_RETURN_IF_ERROR(
+              context_->SetRemoteDeviceFilters(remote_worker, device_filters));
+        }
+      }
+    } else {
+      LOG(WARNING) << "Device filters can only be specified when initializing "
+                      "the cluster. Any changes in device filters are ignored "
+                      "when updating the server def.";
+    }
+  }
+  return UpdateContextWithServerDef(context_, server_def, reset_context,
+                                    keep_alive_secs);
+}
+
+Status EagerContextDistributedManager::CheckRemoteAlive(
+    const std::string& remote_task_name, bool* is_alive) {
+  *is_alive = false;
+  GrpcServer* grpc_server = dynamic_cast<GrpcServer*>(context_->GetServer());
+  if (grpc_server == nullptr) {
+    return errors::Internal("Failed to get eager-compatible server instance.");
+  }
+  WorkerInterface* wi =
+      grpc_server->master_env()->worker_cache->GetOrCreateWorker(
+          remote_task_name);
+  if (wi == nullptr) {
+    return errors::InvalidArgument(
+        "Unable to find worker interface corresponding to task ",
+        remote_task_name);
+  }
+
+  GetStatusRequest request;
+  GetStatusResponse response;
+  Status remote_status;
+  Notification done;
+  wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response,
+                     /*fail_fast=*/true,
+                     [&remote_status, &done](const Status& s) {
+                       remote_status = s;
+                       done.Notify();
+                     });
+  done.WaitForNotification();
+
+  if (remote_status.ok()) {
+    *is_alive = true;
+  } else {
+    LOG(INFO) << "Remote worker " << remote_task_name
+              << " is not alive: " << remote_status.error_message();
+  }
+  return Status::OK();
+}
+#endif  // !IS_MOBILE_PLATFORM
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.h b/tensorflow/core/common_runtime/eager/context_distributed_manager.h
new file mode 100644
index 00000000000..85a7d98cc15
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.h
@@ -0,0 +1,45 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_DISTRIBUTED_MANAGER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_DISTRIBUTED_MANAGER_H_
+
+#include "tensorflow/c/eager/immediate_execution_context.h"
+#include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
+
+namespace tensorflow {
+class EagerContext;
+
+class EagerContextDistributedManager
+    : public ImmediateExecutionDistributedManager {
+ public:
+  explicit EagerContextDistributedManager(EagerContext* context)
+      : context_(context) {}
+
+  Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context,
+                              int keep_alive_secs) override;
+
+  Status CheckRemoteAlive(const std::string& remote_task_name,
+                          bool* is_alive) override;
+
+ private:
+  EagerContext* context_;
+};
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_DISTRIBUTED_MANAGER_H_
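
For orientation, a minimal sketch of how a caller might exercise the new
manager (hypothetical setup: `ctx` and `MakeServerDef()` are placeholders for
an initialized EagerContext and a ServerDef describing the cluster, and the
ownership wiring of the manager inside the runtime is elided):

    // Sketch only; error handling reduced to TF_CHECK_OK.
    tensorflow::EagerContextDistributedManager manager(ctx);
    tensorflow::ServerDef server_def = MakeServerDef();  // hypothetical helper
    // First call: reset_context=true creates a new context_id at view 0.
    TF_CHECK_OK(manager.SetOrUpdateServerDef(server_def,
                                             /*reset_context=*/true,
                                             /*keep_alive_secs=*/600));
    // Probe a remote task with a GetStatus RPC.
    bool is_alive = false;
    TF_CHECK_OK(manager.CheckRemoteAlive("/job:worker/replica:0/task:1",
                                         &is_alive));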