From 39fd4e7c7b307b4e917488a5ce181cfe63452828 Mon Sep 17 00:00:00 2001
From: Haoyu Zhang
Date: Wed, 29 Jan 2020 16:16:12 -0800
Subject: [PATCH] Allow configuring device filters in eager cluster to isolate
 workers.

When device filters are set for remote workers, each worker only has access
to the cluster devices that match its filters. When the cluster needs to be
updated, send a complete update request to the workers whose filtered devices
are affected, and only a simplified view-ID increment request to the other
workers.

PiperOrigin-RevId: 292240301
Change-Id: Icb221c65231a248369f4a341f954b365eb1a42a9
---
 tensorflow/c/eager/c_api.cc                   | 119 +++++++++++++++---
 tensorflow/core/BUILD                         |   1 +
 .../core/common_runtime/eager/context.cc      |  85 +++++++++++++
 .../core/common_runtime/eager/context.h       |  24 ++++
 .../eager/eager_service_impl.cc               |  17 +++
 tensorflow/core/protobuf/device_filters.proto |  72 +++++++++++
 tensorflow/core/protobuf/eager_service.proto  |   5 +-
 .../core/protobuf/tensorflow_server.proto     |  13 +-
 tensorflow/python/eager/remote.py             |  40 +++++-
 tensorflow/python/eager/remote_test.py        |  56 ++++++++-
 tensorflow/python/training/server_lib.py      |  83 ++++++++++++
 ...experimental.-cluster-device-filters.pbtxt |  13 ++
 .../v1/tensorflow.config.experimental.pbtxt   |   4 +
 .../api/golden/v1/tensorflow.config.pbtxt     |   2 +-
 .../v1/tensorflow.train.-server-def.pbtxt     |   7 ++
 ...experimental.-cluster-device-filters.pbtxt |  13 ++
 .../v2/tensorflow.config.experimental.pbtxt   |   4 +
 .../api/golden/v2/tensorflow.config.pbtxt     |   2 +-
 .../v2/tensorflow.train.-server-def.pbtxt     |   7 ++
 19 files changed, 534 insertions(+), 33 deletions(-)
 create mode 100644 tensorflow/core/protobuf/device_filters.proto
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.config.experimental.-cluster-device-filters.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.config.experimental.-cluster-device-filters.pbtxt

diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 5aeabf159af..67da9c4f0a4 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/platform.h" // NOLINT #include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/core/protobuf/device_filters.pb.h" #include "tensorflow/core/util/device_name_utils.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -265,9 +266,9 @@ tensorflow::Status GetReplacedFromExistingWorkers( } tensorflow::Status CreateRemoteContexts( - const std::vector& remote_workers, tensorflow::uint64 context_id, - tensorflow::uint64 context_view_id, int keep_alive_secs, - const tensorflow::ServerDef& server_def, + TFE_Context* ctx, const std::vector& remote_workers, + tensorflow::uint64 context_id, tensorflow::uint64 context_view_id, + int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, const bool lazy_copy_remote_function_inputs, const tensorflow::eager::CreateContextRequest& base_request) { @@ -296,7 +297,7 @@ tensorflow::Status CreateRemoteContexts( continue; } - tensorflow::eager::CreateContextRequest request(base_request); + tensorflow::eager::CreateContextRequest request; tensorflow::eager::CreateContextResponse* response = new tensorflow::eager::CreateContextResponse(); request.set_context_id(context_id); @@ -304,6 +305,21 @@ tensorflow::Status CreateRemoteContexts( *request.mutable_server_def() = server_def; request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); + request.mutable_server_def()->mutable_default_session_config()->MergeFrom( + server_def.default_session_config()); + + std::vector filtered_device_mask; + ctx->context->FilterDevicesForRemoteWorkers( + remote_worker, base_request.cluster_device_attributes(), + &filtered_device_mask); + DCHECK_EQ(filtered_device_mask.size(), + base_request.cluster_device_attributes_size()); + for (int i = 0; i < filtered_device_mask.size(); i++) { + if (filtered_device_mask[i]) { + const auto& da = base_request.cluster_device_attributes(i); + *request.add_cluster_device_attributes() = da; + } + } request.set_async(async); request.set_keep_alive_secs(keep_alive_secs); request.set_lazy_copy_remote_function_inputs( @@ -325,13 +341,34 @@ tensorflow::Status CreateRemoteContexts( } tensorflow::Status UpdateRemoteContexts( - const std::vector& remote_workers, tensorflow::uint64 context_id, + TFE_Context* ctx, const std::vector& remote_workers, + const std::vector& added_workers, + const std::vector& removed_workers, tensorflow::uint64 context_id, tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, const tensorflow::eager::CreateContextRequest& base_request) { int num_remote_workers = remote_workers.size(); tensorflow::BlockingCounter counter(num_remote_workers); std::vector statuses(num_remote_workers); + + int cluster_device_count = base_request.cluster_device_attributes_size(); + std::unordered_set added_or_removed(added_workers.begin(), + added_workers.end()); + std::copy(removed_workers.begin(), removed_workers.end(), + std::inserter(added_or_removed, added_or_removed.end())); + // Whether each device is in the updated (added or removed) workers + std::vector device_added_or_removed(cluster_device_count); + for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) { + const auto& da = base_request.cluster_device_attributes().at(i); + tensorflow::DeviceNameUtils::ParsedName pn; + 
+    tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
+    string task_name;
+    tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
+    if (added_or_removed.find(task_name) != added_or_removed.end()) {
+      device_added_or_removed[i] = true;
+    }
+  }
+
   for (int i = 0; i < num_remote_workers; i++) {
     const string& remote_worker = remote_workers[i];
     tensorflow::DeviceNameUtils::ParsedName parsed_name;
@@ -354,17 +391,42 @@ tensorflow::Status UpdateRemoteContexts(
       continue;
     }
 
+    std::vector<bool> filtered_device_mask;
+    ctx->context->FilterDevicesForRemoteWorkers(
+        remote_worker, base_request.cluster_device_attributes(),
+        &filtered_device_mask);
+    DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
+
+    // If any of the devices that match the device filters are in the set of
+    // added or removed workers, we must send a complete UpdateContextRequest.
+    // Otherwise, only send a simple request to increment context view ID.
+    std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
+    std::transform(device_added_or_removed.begin(),
+                   device_added_or_removed.end(), filtered_device_mask.begin(),
+                   added_or_removed_filtered_devices.begin(),
+                   std::logical_and<bool>());
+    const bool full_update_request =
+        std::accumulate(added_or_removed_filtered_devices.begin(),
+                        added_or_removed_filtered_devices.end(), false,
+                        std::logical_or<bool>());
+
     tensorflow::eager::UpdateContextRequest request;
     auto* response = new tensorflow::eager::UpdateContextResponse();
-
-    *request.mutable_server_def() = server_def;
-    request.mutable_server_def()->set_job_name(parsed_name.job);
-    request.mutable_server_def()->set_task_index(parsed_name.task);
-    for (const auto& da : base_request.cluster_device_attributes()) {
-      *request.add_cluster_device_attributes() = da;
-    }
     request.set_context_id(context_id);
     request.set_context_view_id(context_view_id);
+    if (full_update_request) {
+      *request.mutable_server_def() = server_def;
+      request.mutable_server_def()->set_job_name(parsed_name.job);
+      request.mutable_server_def()->set_task_index(parsed_name.task);
+      request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
+          server_def.default_session_config());
+      for (int i = 0; i < cluster_device_count; i++) {
+        if (filtered_device_mask[i]) {
+          const auto& da = base_request.cluster_device_attributes(i);
+          *request.add_cluster_device_attributes() = da;
+        }
+      }
+    }
 
     eager_client->UpdateContextAsync(
         &request, response,
@@ -525,15 +587,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   for (const auto& da : local_device_attributes) {
     *base_request.add_cluster_device_attributes() = da;
   }
-  base_request.mutable_server_def()
-      ->mutable_default_session_config()
-      ->MergeFrom(server_def.default_session_config());
 
   // Initialize remote eager workers.
   // TODO(b/138847548) Create remote eager contexts in async mode by default.
   if (reset_context) {
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-        remote_workers, context_id, context_view_id, keep_alive_secs,
+        ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
         server_def, remote_eager_workers.get(), context->Executor().Async(),
         context->LazyCopyFunctionRemoteInputs(), base_request));
   } else {
@@ -543,7 +602,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     // we must set their context_view_id to the existing master's
     // context_view_id + 1.
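     // (When device filters are set, each added worker's CreateContextRequest
     // below carries only the cluster devices visible to that worker; see the
     // filtering in CreateRemoteContexts.)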
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-        added_workers, context_id, context_view_id + 1, keep_alive_secs,
+        ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
         server_def, remote_eager_workers.get(), context->Executor().Async(),
         context->LazyCopyFunctionRemoteInputs(), base_request));
     if (!existing_workers.empty()) {
@@ -553,8 +612,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       }
     }
     LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
-        existing_workers, context_id, context_view_id + 1, server_def,
-        remote_eager_workers.get(), base_request));
+        ctx, existing_workers, added_workers, removed_workers, context_id,
+        context_view_id + 1, server_def, remote_eager_workers.get(),
+        base_request));
   }
 }
 
@@ -709,6 +769,22 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
         "Invalid tensorflow.ServerDef protocol buffer");
     return;
   }
+  if (server_def.has_cluster_device_filters()) {
+    const auto& cdf = server_def.cluster_device_filters();
+    for (const auto& jdf : cdf.jobs()) {
+      const string& remote_prefix = "/job:" + jdf.name() + "/task:";
+      for (const auto& tdf : jdf.tasks()) {
+        const int32_t task_index = tdf.first;
+        std::vector<string> device_filters(tdf.second.device_filters_size());
+        for (int i = 0; i < tdf.second.device_filters_size(); i++) {
+          device_filters[i] = tdf.second.device_filters(i);
+        }
+        const string remote_worker = remote_prefix + std::to_string(task_index);
+        status->status =
+            ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
+      }
+    }
+  }
   status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                   ctx, /*reset_context=*/true);
 #endif  // !IS_MOBILE_PLATFORM
@@ -733,6 +809,11 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
     status->status = tensorflow::errors::InvalidArgument(
         "Trying to update a context with invalid context id.");
   }
+  if (server_def.has_cluster_device_filters()) {
+    LOG(WARNING) << "Device filters can only be specified when initializing "
+                    "the cluster. Any changes in device filters are ignored "
+                    "when updating the server def.";
+  }
   // TODO(haoyuzhang): Check server_def compatibility before the update
   status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                   ctx, /*reset_context=*/false);
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b2354839021..634670dd0ac 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -186,6 +186,7 @@ COMMON_PROTO_SRCS = [
     "protobuf/config.proto",
     "protobuf/cluster.proto",
     "protobuf/debug.proto",
+    "protobuf/device_filters.proto",
     "protobuf/device_properties.proto",
     "protobuf/graph_debug_info.proto",
     "protobuf/queue_runner.proto",
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 301b75dfa68..9b0aa3a2c31 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -778,6 +778,11 @@ uint64 EagerContext::GetContextViewId() {
   return context_view_id_;
 }
 
+void EagerContext::IncrementContextViewId() {
+  mutex_lock l(remote_state_mu_);
+  context_view_id_ += 1;
+}
+
 // Set collective ops related state in the context. Passing nullptr to
 // `new_server` will reuse the existing GRPC server in context.
 Status EagerContext::StoreCollectiveOpsServer(
@@ -820,6 +825,86 @@ Status EagerContext::StoreCollectiveOpsServer(
   return Status::OK();
 }
 
+Status EagerContext::SetRemoteDeviceFilters(
+    const string& remote_worker, const std::vector<string>& device_filters) {
+  // Get fully specified task name for remote worker
+  string remote_worker_task_name;
+  DeviceNameUtils::ParsedName pw;
+  if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
+    return tensorflow::errors::InvalidArgument(
+        "Remote worker task name is invalid ", remote_worker);
+  }
+  // Force set a replica as the key in cluster device filters map. I.e., if the
+  // remote worker is `/job:worker/task:0` it then becomes
+  // `/job:worker/replica:0/task:0`.
+  pw.has_replica = true;
+  if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
+    return tensorflow::errors::InvalidArgument(
+        "Job name and task index must be specified for worker ", remote_worker);
+  }
+
+  std::vector<DeviceNameUtils::ParsedName> parsed_filters;
+  for (auto& filter : device_filters) {
+    DeviceNameUtils::ParsedName parsed_filter;
+    if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
+      parsed_filters.emplace_back(parsed_filter);
+    } else {
+      return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
+    }
+  }
+
+  if (VLOG_IS_ON(1)) {
+    VLOG(1) << "Setting device filters for " << remote_worker << ":";
+    for (auto& filter : device_filters) {
+      VLOG(1) << "  " << filter;
+    }
+  }
+  mutex_lock l(remote_state_mu_);
+  cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
+  return Status::OK();
+}
+
+void EagerContext::FilterDevicesForRemoteWorkers(
+    const string& remote_worker,
+    const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
+    std::vector<bool>* filtered_device_mask) {
+  filtered_device_mask->resize(device_attrs.size());
+  std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
+
+  tf_shared_lock l(remote_state_mu_);
+  auto it = cluster_device_filters_.find(remote_worker);
+  // If no filters were specified, all devices should be visible to the worker
+  if (it == cluster_device_filters_.end() || it->second.empty()) {
+    std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
+    return;
+  }
+
+  const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
+  DeviceNameUtils::ParsedName parsed_remote_worker;
+  DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
+  for (int i = 0; i < device_attrs.size(); i++) {
+    DeviceNameUtils::ParsedName pn;
+    DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
+    if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
+      // If this device is on the remote worker itself, it should be visible
+      // regardless of device filters
+      filtered_device_mask->at(i) = true;
+      continue;
+    }
+    for (const auto& pf : parsed_filters) {
+      if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
+          (!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
+          (!pn.has_task || !pf.has_task || pn.task == pf.task) &&
+          (!pn.has_type || !pf.has_type || pn.type == pf.type) &&
+          (!pn.has_id || !pf.has_id || pn.id == pf.id)) {
+        // Found a match, make it visible, stop processing more device filters
+        filtered_device_mask->at(i) = true;
+        break;
+      }
+    }
+  }
+}
+
 Status EagerContext::InitializeRemoteMaster(
     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
     std::shared_ptr<WorkerSession> worker_session,
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index de573410442..44f287d0ad8 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -292,6 +292,7 @@ class EagerContext : public core::RefCounted {
 
   uint64 GetContextId();
   uint64 GetContextViewId();
+  void IncrementContextViewId();
 
   // TODO(nareshmodi): Encapsulate remote state into a separate
   // class/struct.
@@ -358,6 +359,24 @@ class EagerContext : public core::RefCounted {
       std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
       CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
 
+  // For the specified remote worker, preprocess and set its device filters.
+  Status SetRemoteDeviceFilters(const string& remote_worker,
+                                const std::vector<string>& device_filters);
+
+  // For the specified remote worker, apply the stored device filters to the
+  // list of device attributes following these rules:
+  // (1) if the remote worker does not have device filters, all devices are
+  //     visible to the worker;
+  // (2) if the device is on the remote worker, then it is visible;
+  // (3) if the device matches at least one device filter, then it is visible.
+  // The result is saved as a boolean vector of the same length (i.e.,
+  // filtered_device_mask) indicating whether each of the devices is visible to
+  // the remote worker.
+  void FilterDevicesForRemoteWorkers(
+      const string& remote_worker,
+      const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
+      std::vector<bool>* filtered_device_mask);
+
   // TODO(fishx): Remove the custom deleter once we remove forward declaration.
   const std::unique_ptr<eager::RemoteMgr,
                         std::function<void(eager::RemoteMgr*)>>&
@@ -568,6 +587,11 @@ class EagerContext : public core::RefCounted {
   std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
       remote_mgr_;
   bool is_master_ GUARDED_BY(remote_state_mu_);
+
+  // Maps from a remote worker to a list of parsed device filters.
+  std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
+      cluster_device_filters_ GUARDED_BY(remote_state_mu_);
+
 #endif  // IS_MOBILE_PLATFORM
 
   // For a multi device function, the target device of each input is unknown
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 90237f85849..d57aeb77b22 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -119,6 +119,13 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
   auto* r = env_->rendezvous_mgr->Find(request->context_id());
   auto session_name =
       tensorflow::strings::StrCat("eager_", request->context_id());
+  if (VLOG_IS_ON(2)) {
+    VLOG(2) << "Creating context on /job:" << request->server_def().job_name()
+            << "/task:" << request->server_def().task_index();
+    for (const auto& da : request->cluster_device_attributes()) {
+      VLOG(2) << "    " << da.name();
+    }
+  }
   TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
       session_name, request->server_def(), request->cluster_device_attributes(),
       true));
@@ -229,6 +236,16 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
         " but received update request at view #", request->context_view_id(),
         ". View id should only be continuously incremented.");
   }
+  if (request->cluster_device_attributes_size() == 0) {
+    // In this case, the client indicates that the updated `server_def` and
+    // device info are irrelevant to this worker, since it is not connected to
+    // the updated ones (likely due to device filter settings). The worker
+    // simply needs to update the view ID and does not update other internal
+    // state.
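+    // Advancing the view ID keeps this worker in sync with the rest of the
+    // cluster, so the "continuously incremented" check above still passes on
+    // subsequent updates.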
+    ctx->IncrementContextViewId();
+    VLOG(1) << "Processing simplified UpdateContextRequest on "
+            << ctx->HostCPU()->name();
+    return Status::OK();
+  }
 
   // TODO(b/143914772): Potential memory leak if rendezvous has pending
   // tensors for removed / replaced workers.
diff --git a/tensorflow/core/protobuf/device_filters.proto b/tensorflow/core/protobuf/device_filters.proto
new file mode 100644
index 00000000000..0aa38379a2a
--- /dev/null
+++ b/tensorflow/core/protobuf/device_filters.proto
@@ -0,0 +1,72 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+
+option cc_enable_arenas = true;
+option java_outer_classname = "DeviceFiltersProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+// This file contains protos to be used when defining a TensorFlow
+// cluster.
+//
+// Configure device filters for remote tasks in the cluster. When associated
+// with a ClusterDef in setting up the cluster, a remote task will ignore all
+// devices which do not match any of its filters. Device filters must be
+// configured at cluster startup, and cannot be updated once the cluster is
+// up and running.
+//
+// EXAMPLES
+// --------
+//
+// A two-job cluster with the following ClusterDef:
+//
+// Cluster:
+//   job { name: 'worker' tasks { key: 0 value: 'worker1:2222' }
+//                        tasks { key: 1 value: 'worker2:2222' } }
+//   job { name: 'ps'     tasks { key: 0 value: 'ps0:2222' }
+//                        tasks { key: 1 value: 'ps1:2222' } }
+//
+// Set device filters to isolate worker tasks:
+//
+// ClusterDeviceFilters:
+//   job { name: 'worker' tasks { key: 0
+//                                value: device_filter '/job:ps'
+//                                       device_filter '/job:worker/task:0' }
+//                        tasks { key: 1
+//                                value: device_filter '/job:ps'
+//                                       device_filter '/job:worker/task:1' } }
+
+// Defines the device filters for a remote task.
+message TaskDeviceFilters {
+  repeated string device_filters = 1;
+}
+
+// Defines the device filters for tasks in a job.
+message JobDeviceFilters {
+  // The name of this job.
+  string name = 1;
+
+  // Mapping from task ID to task device filters.
+  map<int32, TaskDeviceFilters> tasks = 2;
+}
+
+// Defines the device filters for jobs in a cluster.
+message ClusterDeviceFilters {
+  repeated JobDeviceFilters jobs = 1;
+}
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index 4335d87309a..967df44d3dc 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -109,7 +109,10 @@ message UpdateContextRequest {
   // Identifies the full cluster, and this particular worker's position within.
   ServerDef server_def = 1;
 
-  // Device attributes in the cluster
+  // Device attributes in the cluster.
+  // If this field is empty, it indicates that this is a simple update request
+  // that only increments the cluster view ID and does not require changes to
+  // the workers it connects to.
   repeated DeviceAttributes cluster_device_attributes = 2;
 
   // The ID of the context to be updated. A context with the specified ID must
diff --git a/tensorflow/core/protobuf/tensorflow_server.proto b/tensorflow/core/protobuf/tensorflow_server.proto
index 6ff902cbc97..6b3010ab37e 100644
--- a/tensorflow/core/protobuf/tensorflow_server.proto
+++ b/tensorflow/core/protobuf/tensorflow_server.proto
@@ -15,14 +15,17 @@ limitations under the License.
 
 syntax = "proto3";
 
-import "tensorflow/core/protobuf/config.proto";
-import "tensorflow/core/protobuf/cluster.proto";
-
 package tensorflow;
+
+import "tensorflow/core/protobuf/cluster.proto";
+import "tensorflow/core/protobuf/config.proto";
+import "tensorflow/core/protobuf/device_filters.proto";
+
 option cc_enable_arenas = true;
 option java_outer_classname = "ServerProtos";
 option java_multiple_files = true;
 option java_package = "org.tensorflow.distruntime";
+
 option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf";
 
 // Defines the configuration of a single TensorFlow server.
 message ServerDef {
@@ -51,4 +54,8 @@ message ServerDef {
 
   // The server port. If not set, then we identify the port from the job_name.
   int32 port = 6;
+
+  // Device filters for remote tasks in the cluster.
+  // NOTE: This is an experimental feature and only effective in TensorFlow 2.x.
+  ClusterDeviceFilters cluster_device_filters = 7;
 }
diff --git a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py
index 276f2de9842..dcb22c17ff6 100644
--- a/tensorflow/python/eager/remote.py
+++ b/tensorflow/python/eager/remote.py
@@ -83,7 +83,8 @@ def connect_to_cluster(cluster_spec_or_resolver,
                        job_name="localhost",
                        task_index=0,
                        protocol=None,
-                       make_master_device_default=True):
+                       make_master_device_default=True,
+                       cluster_device_filters=None):
   """Connects to the given cluster.
 
   Will make devices on the cluster available to use. Note that calling this more
@@ -93,6 +94,30 @@ def connect_to_cluster(cluster_spec_or_resolver,
   If the given local job name is not present in the cluster specification, it
   will be automatically added, using an unused port on the localhost.
 
+  Device filters can be specified to isolate groups of remote tasks and avoid
+  undesired access between workers. Accessing resources or launching ops /
+  functions on filtered remote devices will result in errors (unknown devices).
+  For any remote task, if no device filter is present, all cluster devices will
+  be visible; if any device filter is specified, it can only see devices
+  matching at least one filter. Devices on the task itself are always visible.
+  Device filters can be partially specified.
+
+  For example, for a cluster set up for parameter server training, the
+  following device filters might be specified:
+
+  ```python
+  cdf = tf.config.experimental.ClusterDeviceFilters()
+  # For any worker, only the devices on PS nodes and itself are visible
+  for i in range(num_workers):
+    cdf.set_device_filters('worker', i, ['/job:ps'])
+  # Similarly for any ps, only the devices on workers and itself are visible
+  for i in range(num_ps):
+    cdf.set_device_filters('ps', i, ['/job:worker'])
+
+  tf.config.experimental_connect_to_cluster(cluster_def,
+                                            cluster_device_filters=cdf)
+  ```
+
   Args:
     cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
       the cluster.
@@ -105,6 +130,9 @@ def connect_to_cluster(cluster_spec_or_resolver,
       master becomes the default device to run ops. It won't do anything if
       a cluster spec is passed.
       Will throw an error if the caller is currently already in some device
       scope.
+    cluster_device_filters: an instance of
+      `tf.train.experimental.ClusterDeviceFilters` that specifies device
+      filters for the remote tasks in the cluster.
   """
   if not context.executing_eagerly():
     raise ValueError(
@@ -125,6 +153,13 @@ def connect_to_cluster(cluster_spec_or_resolver,
         "`ClusterResolver`.")
     cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
+  if cluster_device_filters:
+    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
+      cluster_device_filters = copy.deepcopy(
+          cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
+    else:
+      raise ValueError("`cluster_device_filters` must be an instance of "
+                       "`tf.train.experimental.ClusterDeviceFilters`.")
 
   # Automatically add local job, if not part of the cluster spec.
   if job_name not in cluster_spec.jobs:
@@ -140,7 +175,8 @@ def connect_to_cluster(cluster_spec_or_resolver,
       job_name=job_name,
       task_index=task_index,
       protocol=protocol,
-      default_session_config=context.context().config)
+      default_session_config=context.context().config,
+      cluster_device_filters=cluster_device_filters)
 
   if context.get_server_def() is None:
     context.set_server_def(server_def)
diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py
index acafbb2626d..275da732c03 100644
--- a/tensorflow/python/eager/remote_test.py
+++ b/tensorflow/python/eager/remote_test.py
@@ -290,13 +290,10 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super(MultiJobsTest, self).setUp()
-    workers, ps = test_util.create_local_cluster(2, 1)
+    workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2)
     cluster = {
-        'my_worker': [
-            _strip_prefix(workers[0].target, _GRPC_PREFIX),
-            _strip_prefix(workers[1].target, _GRPC_PREFIX),
-        ],
-        'my_ps': [_strip_prefix(ps[0].target, _GRPC_PREFIX)],
+        'my_worker': [_strip_prefix(t.target, _GRPC_PREFIX) for t in workers],
+        'my_ps': [_strip_prefix(t.target, _GRPC_PREFIX) for t in ps],
     }
     self._cluster = server_lib.ClusterSpec(cluster)
     self._cluster_resolver = SimpleClusterResolver(
@@ -330,6 +327,53 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
     with ops.device('/job:my_worker/task:1/device:CPU:0'):
       self.assertAllEqual(worker_fn(), 8)
 
+  @test_util.eager_lazy_remote_copy_on_and_off
+  def testSimpleParameterServerWithDeviceFilters(self):
+    cluster_device_filters = server_lib.ClusterDeviceFilters()
+    for i in range(2):
+      cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps'])
+      cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker'])
+    remote.connect_to_cluster(
+        self._cluster, cluster_device_filters=cluster_device_filters)
+
+    with ops.device('/job:my_ps/task:0/device:CPU:0'):
+      v1 = variables.Variable(initial_value=0)
+    with ops.device('/job:my_ps/task:1/device:CPU:0'):
+      v2 = variables.Variable(initial_value=10)
+
+    @def_function.function
+    def worker_fn():
+      v1.assign_add(1)
+      v2.assign_sub(2)
+      return v1.read_value() + v2.read_value()
+
+    with ops.device('/job:my_worker/task:0/device:CPU:0'):
+      self.assertAllEqual(worker_fn(), 9)
+    with ops.device('/job:my_worker/task:1/device:CPU:0'):
+      self.assertAllEqual(worker_fn(), 8)
+
+    # The following remote call would fail because the ps nodes cannot see each
+    # other due to the device filters.
+    with self.assertRaises(errors.InvalidArgumentError) as cm:
+      with ops.device('/job:my_ps/task:0/device:CPU:0'):
+        worker_fn().numpy()
+    self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
+                  cm.exception.message)
+
+    with self.assertRaises(errors.InvalidArgumentError) as cm:
+      with ops.device('/job:my_ps/task:1/device:CPU:0'):
+        worker_fn().numpy()
+    self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
+                  cm.exception.message)
+
+    with ops.device('/job:my_worker/task:0/device:CPU:0'):
+      self.assertAllEqual(worker_fn(), 7)
+    with ops.device('/job:my_worker/task:1/device:CPU:0'):
+      self.assertAllEqual(worker_fn(), 6)
+
+    # Explicitly delete variables to avoid triggering errors when being GC'ed
+    # in subsequent tests.
+    del v1, v2
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testConnectWithClusterResolver(self):
     remote.connect_to_cluster(self._cluster_resolver)
diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py
index 259a9a16c98..b3e840f2f5a 100644
--- a/tensorflow/python/training/server_lib.py
+++ b/tensorflow/python/training/server_lib.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import device_filters_pb2
 from tensorflow.core.protobuf import tensorflow_server_pb2
 from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.framework import errors
@@ -488,3 +489,85 @@ class ClusterSpec(object):
           raise TypeError("Task address %r must be bytes or unicode" %
                           task_address)
         job_def.tasks[i] = task_address
+
+
+@tf_export("config.experimental.ClusterDeviceFilters")
+class ClusterDeviceFilters(object):
+  """Represents a collection of device filters for remote workers in a cluster.
+
+  NOTE: this is an experimental API and subject to changes.
+
+  Set device filters for specific jobs and tasks. For each remote worker, the
+  device filters are a list of strings. When any filters are present, the
+  remote worker will ignore all devices which do not match any of its filters.
+  Each filter can be partially specified, e.g. "/job:ps",
+  "/job:worker/replica:3", etc. Note that a device is always visible to the
+  worker it is located on.
+
+  For example, to set the device filters for a parameter server cluster:
+
+  ```python
+  cdf = tf.config.experimental.ClusterDeviceFilters()
+  for i in range(num_workers):
+    cdf.set_device_filters('worker', i, ['/job:ps'])
+  for i in range(num_ps):
+    cdf.set_device_filters('ps', i, ['/job:worker'])
+
+  tf.config.experimental_connect_to_cluster(cluster_def,
+                                            cluster_device_filters=cdf)
+  ```
+
+  For remote tasks that do not have device filters specified, all devices will
+  be visible to them.
+  """
+
+  def __init__(self):
+    # `_device_filters` is a dict mapping job names to job device filters.
+    # Job device filters further map task IDs to task device filters.
+    # Task device filters are a list of strings, each one a device filter.
+    self._device_filters = {}
+
+    # Serialized protobuf for cluster device filters.
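+    # Built lazily by `_make_cluster_device_filters()` and reset to None
+    # whenever `set_device_filters()` modifies `_device_filters`, so a stale
+    # proto is never returned.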
+    self._cluster_device_filters = None
+
+  def set_device_filters(self, job_name, task_index, device_filters):
+    """Sets the device filters for the given job name and task id."""
+    assert all(isinstance(df, str) for df in device_filters)
+    self._device_filters.setdefault(job_name, {})
+    self._device_filters[job_name][task_index] = [df for df in device_filters]
+    # Due to updates in data, invalidate the serialized proto cache.
+    self._cluster_device_filters = None
+
+  def _as_cluster_device_filters(self):
+    """Returns a serialized protobuf of cluster device filters."""
+    if self._cluster_device_filters:
+      return self._cluster_device_filters
+
+    self._make_cluster_device_filters()
+    return self._cluster_device_filters
+
+  def _make_cluster_device_filters(self):
+    """Creates `ClusterDeviceFilters` proto based on the `_device_filters`.
+
+    Raises:
+      TypeError: If `_device_filters` is not a dictionary mapping strings to
+        a map of task indices and device filters.
+    """
+    self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters()
+
+    # Sort by job_name to produce deterministic protobufs.
+    for job_name, tasks in sorted(self._device_filters.items()):
+      try:
+        job_name = compat.as_bytes(job_name)
+      except TypeError:
+        raise TypeError("Job name %r must be bytes or unicode" % job_name)
+
+      jdf = self._cluster_device_filters.jobs.add()
+      jdf.name = job_name
+
+      for i, task_device_filters in sorted(tasks.items()):
+        for tdf in task_device_filters:
+          try:
+            tdf = compat.as_bytes(tdf)
+          except TypeError:
+            raise TypeError("Device filter %r must be bytes or unicode" % tdf)
+          jdf.tasks[i].device_filters.append(tdf)
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.-cluster-device-filters.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.-cluster-device-filters.pbtxt
new file mode 100644
index 00000000000..8dc3b00f782
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.-cluster-device-filters.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.config.experimental.ClusterDeviceFilters"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.server_lib.ClusterDeviceFilters\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_device_filters"
+    argspec: "args=[\'self\', \'job_name\', \'task_index\', \'device_filters\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
index f4b8bd63b0a..b8f92b30099 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
@@ -1,5 +1,9 @@
 path: "tensorflow.config.experimental"
 tf_module {
+  member {
+    name: "ClusterDeviceFilters"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "VirtualDeviceConfiguration"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt
index b9d1004803f..7876afae9a4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt
@@ -26,7 +26,7 @@ tf_module {
   }
   member_method {
     name: "experimental_connect_to_cluster"
-    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
"args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\', \'cluster_device_filters\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\', \'None\'], " } member_method { name: "experimental_connect_to_host" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-server-def.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-server-def.pbtxt index 03a3a195311..641ea210601 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.-server-def.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-server-def.pbtxt @@ -40,5 +40,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_INT32 } + field { + name: "cluster_device_filters" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ClusterDeviceFilters" + } } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.-cluster-device-filters.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.-cluster-device-filters.pbtxt new file mode 100644 index 00000000000..8dc3b00f782 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.-cluster-device-filters.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.config.experimental.ClusterDeviceFilters" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_device_filters" + argspec: "args=[\'self\', \'job_name\', \'task_index\', \'device_filters\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt index f4b8bd63b0a..b8f92b30099 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.config.experimental" tf_module { + member { + name: "ClusterDeviceFilters" + mtype: "" + } member { name: "VirtualDeviceConfiguration" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt index b9d1004803f..7876afae9a4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt @@ -26,7 +26,7 @@ tf_module { } member_method { name: "experimental_connect_to_cluster" - argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], " + argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\', \'cluster_device_filters\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\', \'None\'], " } member_method { name: "experimental_connect_to_host" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt index 03a3a195311..641ea210601 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt @@ -40,5 +40,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_INT32 } + field { + name: "cluster_device_filters" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: 
".tensorflow.ClusterDeviceFilters" + } } }