diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 5aeabf159af..67da9c4f0a4 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/platform.h"  // NOLINT
 #include "tensorflow/core/protobuf/error_codes.pb.h"
+#include "tensorflow/core/protobuf/device_filters.pb.h"
 #include "tensorflow/core/util/device_name_utils.h"
 #ifdef TENSORFLOW_EAGER_USE_XLA
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -265,9 +266,9 @@ tensorflow::Status GetReplacedFromExistingWorkers(
 }
 
 tensorflow::Status CreateRemoteContexts(
-    const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
-    tensorflow::uint64 context_view_id, int keep_alive_secs,
-    const tensorflow::ServerDef& server_def,
+    TFE_Context* ctx, const std::vector<string>& remote_workers,
+    tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
+    int keep_alive_secs, const tensorflow::ServerDef& server_def,
     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
     const bool lazy_copy_remote_function_inputs,
     const tensorflow::eager::CreateContextRequest& base_request) {
@@ -296,7 +297,7 @@ tensorflow::Status CreateRemoteContexts(
       continue;
     }
 
-    tensorflow::eager::CreateContextRequest request(base_request);
+    tensorflow::eager::CreateContextRequest request;
     tensorflow::eager::CreateContextResponse* response =
         new tensorflow::eager::CreateContextResponse();
     request.set_context_id(context_id);
@@ -304,6 +305,21 @@ tensorflow::Status CreateRemoteContexts(
     *request.mutable_server_def() = server_def;
     request.mutable_server_def()->set_job_name(parsed_name.job);
     request.mutable_server_def()->set_task_index(parsed_name.task);
+    request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
+        server_def.default_session_config());
+
+    std::vector<bool> filtered_device_mask;
+    ctx->context->FilterDevicesForRemoteWorkers(
+        remote_worker, base_request.cluster_device_attributes(),
+        &filtered_device_mask);
+    DCHECK_EQ(filtered_device_mask.size(),
+              base_request.cluster_device_attributes_size());
+    for (int i = 0; i < filtered_device_mask.size(); i++) {
+      if (filtered_device_mask[i]) {
+        const auto& da = base_request.cluster_device_attributes(i);
+        *request.add_cluster_device_attributes() = da;
+      }
+    }
     request.set_async(async);
     request.set_keep_alive_secs(keep_alive_secs);
     request.set_lazy_copy_remote_function_inputs(
@@ -325,13 +341,34 @@ tensorflow::Status CreateRemoteContexts(
 }
 
 tensorflow::Status UpdateRemoteContexts(
-    const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
+    TFE_Context* ctx, const std::vector<string>& remote_workers,
+    const std::vector<string>& added_workers,
+    const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
     tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
     tensorflow::eager::EagerClientCache* remote_eager_workers,
    const tensorflow::eager::CreateContextRequest& base_request) {
   int num_remote_workers = remote_workers.size();
   tensorflow::BlockingCounter counter(num_remote_workers);
   std::vector<tensorflow::Status> statuses(num_remote_workers);
+
+  int cluster_device_count = base_request.cluster_device_attributes_size();
+  std::unordered_set<string> added_or_removed(added_workers.begin(),
+                                              added_workers.end());
+  std::copy(removed_workers.begin(), removed_workers.end(),
+            std::inserter(added_or_removed, added_or_removed.end()));
+  // Whether each device is in the updated (added or removed) workers
+  std::vector<bool> device_added_or_removed(cluster_device_count);
+  for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
+    const auto& da = base_request.cluster_device_attributes().at(i);
+    tensorflow::DeviceNameUtils::ParsedName pn;
+    tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
+    string task_name;
+    tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
+    if (added_or_removed.find(task_name) != added_or_removed.end()) {
+      device_added_or_removed[i] = true;
+    }
+  }
+
   for (int i = 0; i < num_remote_workers; i++) {
     const string& remote_worker = remote_workers[i];
     tensorflow::DeviceNameUtils::ParsedName parsed_name;
@@ -354,17 +391,42 @@ tensorflow::Status UpdateRemoteContexts(
       continue;
     }
 
+    std::vector<bool> filtered_device_mask;
+    ctx->context->FilterDevicesForRemoteWorkers(
+        remote_worker, base_request.cluster_device_attributes(),
+        &filtered_device_mask);
+    DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
+
+    // If any of the devices that match the device filters are in the set of
+    // added or removed workers, we must send a complete UpdateContextRequest.
+    // Otherwise, only send a simple request to increment context view ID.
+    std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
+    std::transform(device_added_or_removed.begin(),
+                   device_added_or_removed.end(), filtered_device_mask.begin(),
+                   added_or_removed_filtered_devices.begin(),
+                   std::logical_and<bool>());
+    const bool full_update_request =
+        std::accumulate(added_or_removed_filtered_devices.begin(),
+                        added_or_removed_filtered_devices.end(), false,
+                        std::logical_or<bool>());
+
     tensorflow::eager::UpdateContextRequest request;
     auto* response = new tensorflow::eager::UpdateContextResponse();
-
-    *request.mutable_server_def() = server_def;
-    request.mutable_server_def()->set_job_name(parsed_name.job);
-    request.mutable_server_def()->set_task_index(parsed_name.task);
-    for (const auto& da : base_request.cluster_device_attributes()) {
-      *request.add_cluster_device_attributes() = da;
-    }
     request.set_context_id(context_id);
     request.set_context_view_id(context_view_id);
+    if (full_update_request) {
+      *request.mutable_server_def() = server_def;
+      request.mutable_server_def()->set_job_name(parsed_name.job);
+      request.mutable_server_def()->set_task_index(parsed_name.task);
+      request.mutable_server_def()
+          ->mutable_default_session_config()
+          ->MergeFrom(server_def.default_session_config());
+      for (int i = 0; i < cluster_device_count; i++) {
+        if (filtered_device_mask[i]) {
+          const auto& da = base_request.cluster_device_attributes(i);
+          *request.add_cluster_device_attributes() = da;
+        }
+      }
+    }
 
     eager_client->UpdateContextAsync(
         &request, response,
@@ -525,15 +587,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   for (const auto& da : local_device_attributes) {
     *base_request.add_cluster_device_attributes() = da;
   }
-  base_request.mutable_server_def()
-      ->mutable_default_session_config()
-      ->MergeFrom(server_def.default_session_config());
 
   // Initialize remote eager workers.
   // TODO(b/138847548) Create remote eager contexts in async mode by default.
   if (reset_context) {
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-        remote_workers, context_id, context_view_id, keep_alive_secs,
+        ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
         server_def, remote_eager_workers.get(), context->Executor().Async(),
         context->LazyCopyFunctionRemoteInputs(), base_request));
   } else {
@@ -543,7 +602,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       // we must set their context_view_id to the existing master's
       // context_view_id + 1.
       LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-          added_workers, context_id, context_view_id + 1, keep_alive_secs,
+          ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
          server_def, remote_eager_workers.get(), context->Executor().Async(),
          context->LazyCopyFunctionRemoteInputs(), base_request));
       if (!existing_workers.empty()) {
@@ -553,8 +612,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
         }
       }
       LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
-          existing_workers, context_id, context_view_id + 1, server_def,
-          remote_eager_workers.get(), base_request));
+          ctx, existing_workers, added_workers, removed_workers, context_id,
+          context_view_id + 1, server_def, remote_eager_workers.get(),
+          base_request));
     }
   }
 
@@ -709,6 +769,22 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
         "Invalid tensorflow.ServerDef protocol buffer");
     return;
   }
+  if (server_def.has_cluster_device_filters()) {
+    const auto& cdf = server_def.cluster_device_filters();
+    for (const auto& jdf : cdf.jobs()) {
+      const string& remote_prefix = "/job:" + jdf.name() + "/task:";
+      for (const auto& tdf : jdf.tasks()) {
+        const int32_t task_index = tdf.first;
+        std::vector<string> device_filters(tdf.second.device_filters_size());
+        for (int i = 0; i < tdf.second.device_filters_size(); i++) {
+          device_filters[i] = tdf.second.device_filters(i);
+        }
+        const string remote_worker = remote_prefix + std::to_string(task_index);
+        status->status = ctx->context->SetRemoteDeviceFilters(remote_worker,
+                                                              device_filters);
+      }
+    }
+  }
   status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                   ctx, /*reset_context=*/true);
 #endif  // !IS_MOBILE_PLATFORM
@@ -733,6 +809,11 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
     status->status = tensorflow::errors::InvalidArgument(
         "Trying to update a context with invalid context id.");
   }
+  if (server_def.has_cluster_device_filters()) {
+    LOG(WARNING) << "Device filters can only be specified when initializing "
+                    "the cluster. Any changes in device filters are ignored "
+                    "when updating the server def.";
+  }
   // TODO(haoyuzhang): Check server_def compatibility before the update
   status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                   ctx, /*reset_context=*/false);
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b2354839021..634670dd0ac 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -186,6 +186,7 @@ COMMON_PROTO_SRCS = [
     "protobuf/config.proto",
     "protobuf/cluster.proto",
     "protobuf/debug.proto",
+    "protobuf/device_filters.proto",
     "protobuf/device_properties.proto",
     "protobuf/graph_debug_info.proto",
     "protobuf/queue_runner.proto",
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 301b75dfa68..9b0aa3a2c31 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -778,6 +778,11 @@ uint64 EagerContext::GetContextViewId() {
   return context_view_id_;
 }
 
+void EagerContext::IncrementContextViewId() {
+  mutex_lock l(remote_state_mu_);
+  context_view_id_ += 1;
+}
+
 // Set collective ops related state in the context. Passing nullptr to
 // `new_server` will reuse the existing GRPC server in context.
 Status EagerContext::StoreCollectiveOpsServer(
@@ -820,6 +825,86 @@ Status EagerContext::StoreCollectiveOpsServer(
   return Status::OK();
 }
 
+Status EagerContext::SetRemoteDeviceFilters(
+    const string& remote_worker, const std::vector<string>& device_filters) {
+  // Get fully specified task name for remote worker
+  string remote_worker_task_name;
+  DeviceNameUtils::ParsedName pw;
+  if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
+    return tensorflow::errors::InvalidArgument(
+        "Remote worker task name is invalid ", remote_worker);
+  }
+  // Force set a replica as the key in cluster device filters map. I.e., if the
+  // remote worker is `/job:worker/task:0` it then becomes
+  // `/job:worker/replica:0/task:0`.
+  pw.has_replica = true;
+  if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
+    return tensorflow::errors::InvalidArgument(
+        "Job name and task index must be specified for worker ",
+        remote_worker);
+  }
+
+  std::vector<DeviceNameUtils::ParsedName> parsed_filters;
+  for (auto& filter : device_filters) {
+    DeviceNameUtils::ParsedName parsed_filter;
+    if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
+      parsed_filters.emplace_back(parsed_filter);
+    } else {
+      return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
+    }
+  }
+
+  if (VLOG_IS_ON(1)) {
+    VLOG(1) << "Setting device filters for " << remote_worker << ":";
+    for (auto& filter : device_filters) {
+      VLOG(1) << "  " << filter;
+    }
+  }
+  mutex_lock l(remote_state_mu_);
+  cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
+  return Status::OK();
+}
+
+void EagerContext::FilterDevicesForRemoteWorkers(
+    const string& remote_worker,
+    const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
+    std::vector<bool>* filtered_device_mask) {
+  filtered_device_mask->resize(device_attrs.size());
+  std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
+
+  tf_shared_lock l(remote_state_mu_);
+  auto it = cluster_device_filters_.find(remote_worker);
+  // If no filters were specified, all devices should be visible to the worker
+  if (it == cluster_device_filters_.end() || it->second.empty()) {
+    std::fill(filtered_device_mask->begin(), filtered_device_mask->end(),
+              true);
+    return;
+  }
+
+  const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
+  DeviceNameUtils::ParsedName parsed_remote_worker;
+  DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
+  for (int i = 0; i < device_attrs.size(); i++) {
+    DeviceNameUtils::ParsedName pn;
+    DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
+    if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
+      // If this device is on the remote worker itself, it should be visible
+      // regardless of device filters
+      filtered_device_mask->at(i) = true;
+      continue;
+    }
+    for (const auto& pf : parsed_filters) {
+      if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
+          (!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
+          (!pn.has_task || !pf.has_task || pn.task == pf.task) &&
+          (!pn.has_type || !pf.has_type || pn.type == pf.type) &&
+          (!pn.has_id || !pf.has_id || pn.id == pf.id)) {
+        // Found a match, make it visible, stop processing more device filters
+        filtered_device_mask->at(i) = true;
+        break;
+      }
+    }
+  }
+}
+
 Status EagerContext::InitializeRemoteMaster(
     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
     std::shared_ptr<WorkerSession> worker_session,
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index de573410442..44f287d0ad8 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -292,6 +292,7 @@ class EagerContext : public core::RefCounted {
 
   uint64 GetContextId();
   uint64 GetContextViewId();
+  void IncrementContextViewId();
 
   // TODO(nareshmodi): Encapsulate remote state into a separate
   // class/struct.
@@ -358,6 +359,24 @@ class EagerContext : public core::RefCounted {
       std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
       CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
 
+  // For the specified remote worker, preprocess and set its device filters.
+  Status SetRemoteDeviceFilters(const string& remote_worker,
+                                const std::vector<string>& device_filters);
+
+  // For the specified remote worker, apply the stored device filters to the
+  // list of device attributes following these rules:
+  // (1) if the remote worker does not have device filters, all devices are
+  //     visible to the worker;
+  // (2) if the device is on the remote worker, then it is visible;
+  // (3) if the device matches at least one device filter, then it is visible.
+  // The result is saved as a boolean vector of the same length (i.e.,
+  // filtered_device_mask) indicating whether each of the devices is visible
+  // to the remote worker.
+  void FilterDevicesForRemoteWorkers(
+      const string& remote_worker,
+      const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
+      std::vector<bool>* filtered_device_mask);
+
   // TODO(fishx): Remove the custom deleter once we remove forward declaration.
   const std::unique_ptr<eager::RemoteMgr,
                         std::function<void(eager::RemoteMgr*)>>&
@@ -568,6 +587,11 @@ class EagerContext : public core::RefCounted {
   std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
       remote_mgr_;
   bool is_master_ GUARDED_BY(remote_state_mu_);
+
+  // Maps from a remote worker to a list of parsed device filters.
+  std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
+      cluster_device_filters_ GUARDED_BY(remote_state_mu_);
+
 #endif  // IS_MOBILE_PLATFORM
 
   // For a multi device function, the target device of each input is unknown
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 90237f85849..d57aeb77b22 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -119,6 +119,13 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
   auto* r = env_->rendezvous_mgr->Find(request->context_id());
   auto session_name =
       tensorflow::strings::StrCat("eager_", request->context_id());
+  if (VLOG_IS_ON(2)) {
+    VLOG(2) << "Creating context on /job:" << request->server_def().job_name()
+            << "/task:" << request->server_def().task_index();
+    for (const auto& da : request->cluster_device_attributes()) {
+      VLOG(2) << "  " << da.name();
+    }
+  }
   TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
       session_name, request->server_def(),
       request->cluster_device_attributes(), true));
@@ -229,6 +236,16 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
         " but received update request at view #", request->context_view_id(),
         ". View id should only be continuously incremented.");
   }
+  if (request->cluster_device_attributes_size() == 0) {
+    // In this case, the client indicates that the updated `server_def` and
+    // device info is irrelevant to this worker, since it is not connected to
+    // the updated ones (likely due to device filter settings). The worker
+    // simply needs to update view ID and does not update other internal state.
+    ctx->IncrementContextViewId();
+    VLOG(1) << "Processing simplified UpdateContextRequest on "
+            << ctx->HostCPU()->name();
+    return Status::OK();
+  }
 
   // TODO(b/143914772): Potential memory leak if rendezvous has pending
   // tensors for removed / replaced workers.
diff --git a/tensorflow/core/protobuf/device_filters.proto b/tensorflow/core/protobuf/device_filters.proto
new file mode 100644
index 00000000000..0aa38379a2a
--- /dev/null
+++ b/tensorflow/core/protobuf/device_filters.proto
@@ -0,0 +1,72 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+
+option cc_enable_arenas = true;
+option java_outer_classname = "DeviceFiltersProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+// This file contains protos to be used when defining a TensorFlow
+// cluster.
+//
+// Configure device filters for remote tasks in the cluster. When associated
+// with a ClusterDef in setting up the cluster, a remote task will ignore all
+// devices which do not match any of its filters. Device filters must be
+// configured at cluster startup, and cannot be updated once the cluster is
+// up and running.
+//
+// EXAMPLES
+// --------
+//
+// A two-job cluster with the following ClusterDef:
+//
+// Cluster:
+//   job { name: 'worker' tasks { key: 0 value: 'worker1:2222' }
+//                        tasks { key: 1 value: 'worker2:2222' } }
+//   job { name: 'ps'     tasks { key: 0 value: 'ps0:2222' }
+//                        tasks { key: 1 value: 'ps1:2222' } }
+//
+// Set device filters to isolate worker tasks:
+//
+// ClusterDeviceFilters:
+//   job { name: 'worker' tasks { key: 0
+//                                value: device_filter '/job:ps'
+//                                       device_filter '/job:worker/task:0' }
+//                        tasks { key: 1
+//                                value: device_filter '/job:ps'
+//                                       device_filter '/job:worker/task:1' } }
+
+// Defines the device filters for a remote task.
+message TaskDeviceFilters {
+  repeated string device_filters = 1;
+}
+
+// Defines the device filters for tasks in a job.
+message JobDeviceFilters {
+  // The name of this job.
+  string name = 1;
+
+  // Mapping from task ID to task device filters.
+  map<int32, TaskDeviceFilters> tasks = 2;
+}
+
+// Defines the device filters for jobs in a cluster.
+message ClusterDeviceFilters {
+  repeated JobDeviceFilters jobs = 1;
+}
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index 4335d87309a..967df44d3dc 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -109,7 +109,10 @@ message UpdateContextRequest {
   // Identifies the full cluster, and this particular worker's position within.
   ServerDef server_def = 1;
 
-  // Device attributes in the cluster
+  // Device attributes in the cluster.
+  // If this field is empty, it indicates that this is a simple update request
+  // that only increments the cluster view ID and does not require changes to
+  // the workers it connects to.
   repeated DeviceAttributes cluster_device_attributes = 2;
 
   // The ID of the context to be updated. A context with the specified ID must
diff --git a/tensorflow/core/protobuf/tensorflow_server.proto b/tensorflow/core/protobuf/tensorflow_server.proto
index 6ff902cbc97..6b3010ab37e 100644
--- a/tensorflow/core/protobuf/tensorflow_server.proto
+++ b/tensorflow/core/protobuf/tensorflow_server.proto
@@ -15,14 +15,17 @@ limitations under the License.
 syntax = "proto3";
 
-import "tensorflow/core/protobuf/config.proto";
-import "tensorflow/core/protobuf/cluster.proto";
-
 package tensorflow;
+
+import "tensorflow/core/protobuf/cluster.proto";
+import "tensorflow/core/protobuf/config.proto";
+import "tensorflow/core/protobuf/device_filters.proto";
+
 option cc_enable_arenas = true;
 option java_outer_classname = "ServerProtos";
 option java_multiple_files = true;
 option java_package = "org.tensorflow.distruntime";
+
 option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf";
 
 // Defines the configuration of a single TensorFlow server.
 message ServerDef {
@@ -51,4 +54,8 @@ message ServerDef {
 
   // The server port. If not set, then we identify the port from the job_name.
   int32 port = 6;
+
+  // Device filters for remote tasks in the cluster.
+  // NOTE: This is an experimental feature and only effective in TensorFlow 2.x.
+  ClusterDeviceFilters cluster_device_filters = 7;
 }
diff --git a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py
index 276f2de9842..dcb22c17ff6 100644
--- a/tensorflow/python/eager/remote.py
+++ b/tensorflow/python/eager/remote.py
@@ -83,7 +83,8 @@ def connect_to_cluster(cluster_spec_or_resolver,
                        job_name="localhost",
                        task_index=0,
                        protocol=None,
-                       make_master_device_default=True):
+                       make_master_device_default=True,
+                       cluster_device_filters=None):
   """Connects to the given cluster.
 
   Will make devices on the cluster available to use. Note that calling this more
@@ -93,6 +94,30 @@ def connect_to_cluster(cluster_spec_or_resolver,
   If the given local job name is not present in the cluster specification, it
   will be automatically added, using an unused port on the localhost.
 
+  Device filters can be specified to isolate groups of remote tasks to avoid
+  undesired accesses between workers. Workers accessing resources or launching
+  ops / functions on filtered remote devices will result in errors (unknown
+  devices). For any remote task, if no device filter is present, all cluster
+  devices will be visible; if any device filter is specified, it can only
+  see devices matching at least one filter. Devices on the task itself are
+  always visible. Device filters can be partially specified.
+
+  For example, for a cluster set up for parameter server training, the
+  following device filters might be specified:
+
+  ```python
+  cdf = tf.config.experimental.ClusterDeviceFilters()
+  # For any worker, only the devices on PS nodes and itself are visible
+  for i in range(num_workers):
+    cdf.set_device_filters('worker', i, ['/job:ps'])
+  # Similarly for any ps, only the devices on workers and itself are visible
+  for i in range(num_ps):
+    cdf.set_device_filters('ps', i, ['/job:worker'])
+
+  tf.config.experimental_connect_to_cluster(cluster_def,
+                                            cluster_device_filters=cdf)
+  ```
+
   Args:
     cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
       the cluster.
@@ -105,6 +130,9 @@ def connect_to_cluster(cluster_spec_or_resolver,
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
+    cluster_device_filters: an instance of
+      `tf.train.experimental.ClusterDeviceFilters` that specifies device
+      filters for the remote tasks in the cluster.
   """
   if not context.executing_eagerly():
     raise ValueError(
@@ -125,6 +153,13 @@ def connect_to_cluster(cluster_spec_or_resolver,
         "`ClusterResolver`.")
 
   cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
+  if cluster_device_filters:
+    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
+      cluster_device_filters = copy.deepcopy(
+          cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
+    else:
+      raise ValueError("`cluster_device_filters` must be an instance of "
+                       "`tf.train.experimental.ClusterDeviceFilters`.")
 
   # Automatically add local job, if not part of the cluster spec.
   if job_name not in cluster_spec.jobs:
@@ -140,7 +175,8 @@ def connect_to_cluster(cluster_spec_or_resolver,
       job_name=job_name,
       task_index=task_index,
       protocol=protocol,
-      default_session_config=context.context().config)
+      default_session_config=context.context().config,
+      cluster_device_filters=cluster_device_filters)
 
   if context.get_server_def() is None:
     context.set_server_def(server_def)
diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py
index acafbb2626d..275da732c03 100644
--- a/tensorflow/python/eager/remote_test.py
+++ b/tensorflow/python/eager/remote_test.py
@@ -290,13 +290,10 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
     super(MultiJobsTest, self).setUp()
-    workers, ps = test_util.create_local_cluster(2, 1)
+    workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2)
     cluster = {
-        'my_worker': [
-            _strip_prefix(workers[0].target, _GRPC_PREFIX),
-            _strip_prefix(workers[1].target, _GRPC_PREFIX),
-        ],
-        'my_ps': [_strip_prefix(ps[0].target, _GRPC_PREFIX)],
+        'my_worker': [_strip_prefix(t.target, _GRPC_PREFIX) for t in workers],
+        'my_ps': [_strip_prefix(t.target, _GRPC_PREFIX) for t in ps],
     }
     self._cluster = server_lib.ClusterSpec(cluster)
     self._cluster_resolver = SimpleClusterResolver(
@@ -330,6 +327,53 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
     with ops.device('/job:my_worker/task:1/device:CPU:0'):
       self.assertAllEqual(worker_fn(), 8)
 
+  @test_util.eager_lazy_remote_copy_on_and_off
+  def testSimpleParameterServerWithDeviceFilters(self):
+    cluster_device_filters = server_lib.ClusterDeviceFilters()
+    for i in range(2):
+      cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps'])
+      cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker'])
+    remote.connect_to_cluster(
+        self._cluster, cluster_device_filters=cluster_device_filters)
+
+    with ops.device('/job:my_ps/task:0/device:CPU:0'):
+      v1 = variables.Variable(initial_value=0)
+    with ops.device('/job:my_ps/task:1/device:CPU:0'):
+      v2 = variables.Variable(initial_value=10)
+
+    @def_function.function
+    def worker_fn():
+      v1.assign_add(1)
+      v2.assign_sub(2)
+      return v1.read_value() + v2.read_value()
+
+    with ops.device('/job:my_worker/task:0/device:CPU:0'):
+      self.assertAllEqual(worker_fn(), 9)
+    with ops.device('/job:my_worker/task:1/device:CPU:0'):
+      self.assertAllEqual(worker_fn(), 8)
+
+    # The following remote call would fail because the ps nodes cannot see each
+    # other due to the device filters.
+    with self.assertRaises(errors.InvalidArgumentError) as cm:
+      with ops.device('/job:my_ps/task:0/device:CPU:0'):
+        worker_fn().numpy()
+    self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
+                  cm.exception.message)
+
+    with self.assertRaises(errors.InvalidArgumentError) as cm:
+      with ops.device('/job:my_ps/task:1/device:CPU:0'):
+        worker_fn().numpy()
+    self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
+                  cm.exception.message)
+
+    with ops.device('/job:my_worker/task:0/device:CPU:0'):
+      self.assertAllEqual(worker_fn(), 7)
+    with ops.device('/job:my_worker/task:1/device:CPU:0'):
+      self.assertAllEqual(worker_fn(), 6)
+
+    # Explicitly delete variables to avoid triggering errors when being GC'ed
+    # in subsequent tests.
+    del v1, v2
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testConnectWithClusterResolver(self):
     remote.connect_to_cluster(self._cluster_resolver)
diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py
index 259a9a16c98..b3e840f2f5a 100644
--- a/tensorflow/python/training/server_lib.py
+++ b/tensorflow/python/training/server_lib.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import device_filters_pb2
 from tensorflow.core.protobuf import tensorflow_server_pb2
 from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.framework import errors
@@ -488,3 +489,85 @@ class ClusterSpec(object):
           raise TypeError("Task address %r must be bytes or unicode" %
                           task_address)
         job_def.tasks[i] = task_address
+
+
+@tf_export("config.experimental.ClusterDeviceFilters")
+class ClusterDeviceFilters(object):
+  """Represents a collection of device filters for remote workers in a cluster.
+
+  NOTE: this is an experimental API and subject to changes.
+
+  Set device filters for selective jobs and tasks. For each remote worker, the
+  device filters are a list of strings. When any filters are present, the
+  remote worker will ignore all devices which do not match any of its filters.
+  Each filter can be partially specified, e.g. "/job:ps",
+  "/job:worker/replica:3", etc. Note that a device is always visible to the
+  worker it is located on.
+
+  For example, to set the device filters for a parameter server cluster:
+
+  ```python
+  cdf = tf.config.experimental.ClusterDeviceFilters()
+  for i in range(num_workers):
+    cdf.set_device_filters('worker', i, ['/job:ps'])
+  for i in range(num_ps):
+    cdf.set_device_filters('ps', i, ['/job:worker'])
+
+  tf.config.experimental_connect_to_cluster(cluster_def,
+                                            cluster_device_filters=cdf)
+  ```
+
+  The device filters can be partially specified. For remote tasks that do not
+  have device filters specified, all devices will be visible to them.
+  """
+
+  def __init__(self):
+    # `_device_filters` is a dict mapping job names to job device filters.
+    # Job device filters further maps task IDs to task device filters.
+    # Task device filters are a list of strings, each one is a device filter.
+    self._device_filters = {}
+
+    # Serialized protobuf for cluster device filters.
+    self._cluster_device_filters = None
+
+  def set_device_filters(self, job_name, task_index, device_filters):
+    """Set the device filters for given job name and task id."""
+    assert all(isinstance(df, str) for df in device_filters)
+    self._device_filters.setdefault(job_name, {})
+    self._device_filters[job_name][task_index] = [df for df in device_filters]
+    # Due to updates in data, invalidate the serialized proto cache.
+    self._cluster_device_filters = None
+
+  def _as_cluster_device_filters(self):
+    """Returns a serialized protobuf of cluster device filters."""
+    if self._cluster_device_filters:
+      return self._cluster_device_filters
+
+    self._make_cluster_device_filters()
+    return self._cluster_device_filters
+
+  def _make_cluster_device_filters(self):
+    """Creates `ClusterDeviceFilters` proto based on the `_device_filters`.
+
+    Raises:
+      TypeError: If `_device_filters` is not a dictionary mapping strings to
+        a map of task indices and device filters.
+    """
+    self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters()
+
+    # Sort by job_name to produce deterministic protobufs.
+    for job_name, tasks in sorted(self._device_filters.items()):
+      try:
+        job_name = compat.as_bytes(job_name)
+      except TypeError:
+        raise TypeError("Job name %r must be bytes or unicode" % job_name)
+
+      jdf = self._cluster_device_filters.jobs.add()
+      jdf.name = job_name
+
+      for i, task_device_filters in sorted(tasks.items()):
+        for tdf in task_device_filters:
+          try:
+            tdf = compat.as_bytes(tdf)
+          except TypeError:
+            raise TypeError("Device filter %r must be bytes or unicode" % tdf)
+          jdf.tasks[i].device_filters.append(tdf)
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.-cluster-device-filters.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.-cluster-device-filters.pbtxt
new file mode 100644
index 00000000000..8dc3b00f782
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.-cluster-device-filters.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.config.experimental.ClusterDeviceFilters"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.server_lib.ClusterDeviceFilters\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_device_filters"
+    argspec: "args=[\'self\', \'job_name\', \'task_index\', \'device_filters\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
index f4b8bd63b0a..b8f92b30099 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
@@ -1,5 +1,9 @@
 path: "tensorflow.config.experimental"
 tf_module {
+  member {
+    name: "ClusterDeviceFilters"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "VirtualDeviceConfiguration"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt
index b9d1004803f..7876afae9a4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt
@@ -26,7 +26,7 @@ tf_module {
   }
   member_method {
     name: "experimental_connect_to_cluster"
-    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
+    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\', \'cluster_device_filters\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\', \'None\'], "
   }
   member_method {
     name: "experimental_connect_to_host"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-server-def.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-server-def.pbtxt
index 03a3a195311..641ea210601 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.-server-def.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-server-def.pbtxt
@@ -40,5 +40,12 @@ tf_proto {
       label: LABEL_OPTIONAL
       type: TYPE_INT32
     }
+    field {
+      name: "cluster_device_filters"
+      number: 7
+      label: LABEL_OPTIONAL
+      type: TYPE_MESSAGE
+      type_name: ".tensorflow.ClusterDeviceFilters"
+    }
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.-cluster-device-filters.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.-cluster-device-filters.pbtxt
new file mode 100644
index 00000000000..8dc3b00f782
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.-cluster-device-filters.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.config.experimental.ClusterDeviceFilters"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.server_lib.ClusterDeviceFilters\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_device_filters"
+    argspec: "args=[\'self\', \'job_name\', \'task_index\', \'device_filters\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
index f4b8bd63b0a..b8f92b30099 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
@@ -1,5 +1,9 @@
 path: "tensorflow.config.experimental"
 tf_module {
+  member {
+    name: "ClusterDeviceFilters"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "VirtualDeviceConfiguration"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt
index b9d1004803f..7876afae9a4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt
@@ -26,7 +26,7 @@ tf_module {
   }
   member_method {
     name: "experimental_connect_to_cluster"
-    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
+    argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\', \'cluster_device_filters\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\', \'None\'], "
   }
   member_method {
     name: "experimental_connect_to_host"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt
index 03a3a195311..641ea210601 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt
@@ -40,5 +40,12 @@ tf_proto {
      label: LABEL_OPTIONAL
      type: TYPE_INT32
    }
+    field {
+      name: "cluster_device_filters"
+      number: 7
+      label: LABEL_OPTIONAL
+      type: TYPE_MESSAGE
+      type_name: ".tensorflow.ClusterDeviceFilters"
+    }
  }
}