Allow configuring device filters in eager cluster to isolate workers.
When device filters are set for remote workers, they will only have access to cluster devices that match the filters. If we need to update the cluster, send complete update request to workers that match the device filters, and only send simplified view ID increment request to other workers. PiperOrigin-RevId: 292240301 Change-Id: Icb221c65231a248369f4a341f954b365eb1a42a9
This commit is contained in:
parent
a48e6bb7d8
commit
39fd4e7c7b
@ -44,6 +44,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/platform.h" // NOLINT
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/protobuf/device_filters.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
@ -265,9 +266,9 @@ tensorflow::Status GetReplacedFromExistingWorkers(
|
||||
}
|
||||
|
||||
tensorflow::Status CreateRemoteContexts(
|
||||
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
|
||||
tensorflow::uint64 context_view_id, int keep_alive_secs,
|
||||
const tensorflow::ServerDef& server_def,
|
||||
TFE_Context* ctx, const std::vector<string>& remote_workers,
|
||||
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
|
||||
const bool lazy_copy_remote_function_inputs,
|
||||
const tensorflow::eager::CreateContextRequest& base_request) {
|
||||
@ -296,7 +297,7 @@ tensorflow::Status CreateRemoteContexts(
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::eager::CreateContextRequest request(base_request);
|
||||
tensorflow::eager::CreateContextRequest request;
|
||||
tensorflow::eager::CreateContextResponse* response =
|
||||
new tensorflow::eager::CreateContextResponse();
|
||||
request.set_context_id(context_id);
|
||||
@ -304,6 +305,21 @@ tensorflow::Status CreateRemoteContexts(
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
|
||||
server_def.default_session_config());
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
ctx->context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(),
|
||||
base_request.cluster_device_attributes_size());
|
||||
for (int i = 0; i < filtered_device_mask.size(); i++) {
|
||||
if (filtered_device_mask[i]) {
|
||||
const auto& da = base_request.cluster_device_attributes(i);
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
}
|
||||
request.set_async(async);
|
||||
request.set_keep_alive_secs(keep_alive_secs);
|
||||
request.set_lazy_copy_remote_function_inputs(
|
||||
@ -325,13 +341,34 @@ tensorflow::Status CreateRemoteContexts(
|
||||
}
|
||||
|
||||
tensorflow::Status UpdateRemoteContexts(
|
||||
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
|
||||
TFE_Context* ctx, const std::vector<string>& remote_workers,
|
||||
const std::vector<string>& added_workers,
|
||||
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
|
||||
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers,
|
||||
const tensorflow::eager::CreateContextRequest& base_request) {
|
||||
int num_remote_workers = remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_remote_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_remote_workers);
|
||||
|
||||
int cluster_device_count = base_request.cluster_device_attributes_size();
|
||||
std::unordered_set<string> added_or_removed(added_workers.begin(),
|
||||
added_workers.end());
|
||||
std::copy(removed_workers.begin(), removed_workers.end(),
|
||||
std::inserter(added_or_removed, added_or_removed.end()));
|
||||
// Whether each device is in the updated (added or removed) workers
|
||||
std::vector<bool> device_added_or_removed(cluster_device_count);
|
||||
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
|
||||
const auto& da = base_request.cluster_device_attributes().at(i);
|
||||
tensorflow::DeviceNameUtils::ParsedName pn;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
|
||||
string task_name;
|
||||
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
|
||||
if (added_or_removed.find(task_name) != added_or_removed.end()) {
|
||||
device_added_or_removed[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
const string& remote_worker = remote_workers[i];
|
||||
tensorflow::DeviceNameUtils::ParsedName parsed_name;
|
||||
@ -354,17 +391,42 @@ tensorflow::Status UpdateRemoteContexts(
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
ctx->context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
|
||||
|
||||
// If any of the devices that match the device filters are in the set of
|
||||
// added or removed workers, we must send a complete UpdateContextRequest.
|
||||
// Otherwise, only send a simple request to increment context view ID.
|
||||
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
|
||||
std::transform(device_added_or_removed.begin(),
|
||||
device_added_or_removed.end(), filtered_device_mask.begin(),
|
||||
added_or_removed_filtered_devices.begin(),
|
||||
std::logical_and<bool>());
|
||||
const bool full_update_request =
|
||||
std::accumulate(added_or_removed_filtered_devices.begin(),
|
||||
added_or_removed_filtered_devices.end(), false,
|
||||
std::logical_or<bool>());
|
||||
|
||||
tensorflow::eager::UpdateContextRequest request;
|
||||
auto* response = new tensorflow::eager::UpdateContextResponse();
|
||||
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
for (const auto& da : base_request.cluster_device_attributes()) {
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
request.set_context_id(context_id);
|
||||
request.set_context_view_id(context_view_id);
|
||||
if (full_update_request) {
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
|
||||
server_def.default_session_config());
|
||||
for (int i = 0; i < cluster_device_count; i++) {
|
||||
if (filtered_device_mask[i]) {
|
||||
const auto& da = base_request.cluster_device_attributes(i);
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eager_client->UpdateContextAsync(
|
||||
&request, response,
|
||||
@ -525,15 +587,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
for (const auto& da : local_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
base_request.mutable_server_def()
|
||||
->mutable_default_session_config()
|
||||
->MergeFrom(server_def.default_session_config());
|
||||
|
||||
// Initialize remote eager workers.
|
||||
// TODO(b/138847548) Create remote eager contexts in async mode by default.
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
} else {
|
||||
@ -543,7 +602,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
// we must set their context_view_id to the existing master's
|
||||
// context_view_id + 1.
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
||||
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
if (!existing_workers.empty()) {
|
||||
@ -553,8 +612,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
}
|
||||
}
|
||||
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
|
||||
existing_workers, context_id, context_view_id + 1, server_def,
|
||||
remote_eager_workers.get(), base_request));
|
||||
ctx, existing_workers, added_workers, removed_workers, context_id,
|
||||
context_view_id + 1, server_def, remote_eager_workers.get(),
|
||||
base_request));
|
||||
}
|
||||
}
|
||||
|
||||
@ -709,6 +769,22 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
return;
|
||||
}
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
const auto& cdf = server_def.cluster_device_filters();
|
||||
for (const auto& jdf : cdf.jobs()) {
|
||||
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
for (const auto& tdf : jdf.tasks()) {
|
||||
const int32_t task_index = tdf.first;
|
||||
std::vector<string> device_filters(tdf.second.device_filters_size());
|
||||
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
|
||||
device_filters[i] = tdf.second.device_filters(i);
|
||||
}
|
||||
const string remote_worker = remote_prefix + std::to_string(task_index);
|
||||
status->status =
|
||||
ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
||||
}
|
||||
}
|
||||
}
|
||||
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
|
||||
ctx, /*reset_context=*/true);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
@ -733,6 +809,11 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Trying to update a context with invalid context id.");
|
||||
}
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
LOG(WARNING) << "Device filters can only be specified when initializing "
|
||||
"the cluster. Any changes in device filters are ignored "
|
||||
"when updating the server def.";
|
||||
}
|
||||
// TODO(haoyuzhang): Check server_def compatibility before the update
|
||||
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
|
||||
ctx, /*reset_context=*/false);
|
||||
|
@ -186,6 +186,7 @@ COMMON_PROTO_SRCS = [
|
||||
"protobuf/config.proto",
|
||||
"protobuf/cluster.proto",
|
||||
"protobuf/debug.proto",
|
||||
"protobuf/device_filters.proto",
|
||||
"protobuf/device_properties.proto",
|
||||
"protobuf/graph_debug_info.proto",
|
||||
"protobuf/queue_runner.proto",
|
||||
|
@ -778,6 +778,11 @@ uint64 EagerContext::GetContextViewId() {
|
||||
return context_view_id_;
|
||||
}
|
||||
|
||||
void EagerContext::IncrementContextViewId() {
|
||||
mutex_lock l(remote_state_mu_);
|
||||
context_view_id_ += 1;
|
||||
}
|
||||
|
||||
// Set collective ops related state in the context. Passing nullptr to
|
||||
// `new_server` will reuse the existing GRPC server in context.
|
||||
Status EagerContext::StoreCollectiveOpsServer(
|
||||
@ -820,6 +825,86 @@ Status EagerContext::StoreCollectiveOpsServer(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EagerContext::SetRemoteDeviceFilters(
|
||||
const string& remote_worker, const std::vector<string>& device_filters) {
|
||||
// Get fully specified task name for remote worker
|
||||
string remote_worker_task_name;
|
||||
DeviceNameUtils::ParsedName pw;
|
||||
if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Remote worker task name is invalid ", remote_worker);
|
||||
}
|
||||
// Force set a replica as the key in cluster device filters map. I.e., if the
|
||||
// remote worker is `/job:worker/task:0` it then becomes
|
||||
// `/job:worker/replica:0/task:0`.
|
||||
pw.has_replica = true;
|
||||
if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Job name and task index must be specified for worker ", remote_worker);
|
||||
}
|
||||
|
||||
std::vector<DeviceNameUtils::ParsedName> parsed_filters;
|
||||
for (auto& filter : device_filters) {
|
||||
DeviceNameUtils::ParsedName parsed_filter;
|
||||
if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
|
||||
parsed_filters.emplace_back(parsed_filter);
|
||||
} else {
|
||||
return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
|
||||
}
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(1) << "Setting device filters for " << remote_worker << ":";
|
||||
for (auto& filter : device_filters) {
|
||||
VLOG(1) << " " << filter;
|
||||
}
|
||||
}
|
||||
mutex_lock l(remote_state_mu_);
|
||||
cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void EagerContext::FilterDevicesForRemoteWorkers(
|
||||
const string& remote_worker,
|
||||
const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
|
||||
std::vector<bool>* filtered_device_mask) {
|
||||
filtered_device_mask->resize(device_attrs.size());
|
||||
std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
|
||||
|
||||
tf_shared_lock l(remote_state_mu_);
|
||||
auto it = cluster_device_filters_.find(remote_worker);
|
||||
// If no filters were specified, all devices should be visible to the worker
|
||||
if (it == cluster_device_filters_.end() || it->second.empty()) {
|
||||
std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
|
||||
return;
|
||||
}
|
||||
|
||||
const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
|
||||
DeviceNameUtils::ParsedName parsed_remote_worker;
|
||||
DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
|
||||
for (int i = 0; i < device_attrs.size(); i++) {
|
||||
DeviceNameUtils::ParsedName pn;
|
||||
DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
|
||||
if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
|
||||
// If this device is on the remote worker itself, it should be visible
|
||||
// regardless of device filters
|
||||
filtered_device_mask->at(i) = true;
|
||||
continue;
|
||||
}
|
||||
for (const auto& pf : parsed_filters) {
|
||||
if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
|
||||
(!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
|
||||
(!pn.has_task || !pf.has_task || pn.task == pf.task) &&
|
||||
(!pn.has_type || !pf.has_type || pn.type == pf.type) &&
|
||||
(!pn.has_id || !pf.has_id || pn.id == pf.id)) {
|
||||
// Found a match, make it visible, stop processing more device filters
|
||||
filtered_device_mask->at(i) = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status EagerContext::InitializeRemoteMaster(
|
||||
std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
|
||||
std::shared_ptr<WorkerSession> worker_session,
|
||||
|
@ -292,6 +292,7 @@ class EagerContext : public core::RefCounted {
|
||||
|
||||
uint64 GetContextId();
|
||||
uint64 GetContextViewId();
|
||||
void IncrementContextViewId();
|
||||
|
||||
// TODO(nareshmodi): Encapsulate remote state into a separate
|
||||
// class/struct.
|
||||
@ -358,6 +359,24 @@ class EagerContext : public core::RefCounted {
|
||||
std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
|
||||
CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
|
||||
|
||||
// For the specified remote worker, preprocess and set its device filters.
|
||||
Status SetRemoteDeviceFilters(const string& remote_worker,
|
||||
const std::vector<string>& device_filters);
|
||||
|
||||
// For the specified remote worker, apply the stored device filters to the
|
||||
// list of device attributes following these rules:
|
||||
// (1) if the remote worker does not have device filters, all devices are
|
||||
// visible to the worker;
|
||||
// (2) if the device is on the remote worker, then it is visible;
|
||||
// (3) if the device matches at least one device filter, then it is visible.
|
||||
// The result is saved as a boolean vector of the same length (i.e.,
|
||||
// filtered_device_mask) indicating whether each of the devices is visible to
|
||||
// the remote worker.
|
||||
void FilterDevicesForRemoteWorkers(
|
||||
const string& remote_worker,
|
||||
const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
|
||||
std::vector<bool>* filtered_device_mask);
|
||||
|
||||
// TODO(fishx): Remove the custom deleter once we remove forward declaration.
|
||||
const std::unique_ptr<eager::RemoteMgr,
|
||||
std::function<void(eager::RemoteMgr*)>>&
|
||||
@ -568,6 +587,11 @@ class EagerContext : public core::RefCounted {
|
||||
std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
|
||||
remote_mgr_;
|
||||
bool is_master_ GUARDED_BY(remote_state_mu_);
|
||||
|
||||
// Maps from a remote worker to a list of parsed device filters.
|
||||
std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
|
||||
cluster_device_filters_ GUARDED_BY(remote_state_mu_);
|
||||
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
|
||||
// For a multi device function, the target device of each input is unknown
|
||||
|
@ -119,6 +119,13 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
|
||||
auto* r = env_->rendezvous_mgr->Find(request->context_id());
|
||||
auto session_name =
|
||||
tensorflow::strings::StrCat("eager_", request->context_id());
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "Creating context on /job:" << request->server_def().job_name()
|
||||
<< "/task:" << request->server_def().task_index();
|
||||
for (const auto& da : request->cluster_device_attributes()) {
|
||||
VLOG(2) << " " << da.name();
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
|
||||
session_name, request->server_def(), request->cluster_device_attributes(),
|
||||
true));
|
||||
@ -229,6 +236,16 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
|
||||
" but received update request at view #", request->context_view_id(),
|
||||
". View id should only be continuously incremented.");
|
||||
}
|
||||
if (request->cluster_device_attributes_size() == 0) {
|
||||
// In this case, the client indicates that the updated `server_def` and
|
||||
// device info is irrelevant to this worker, since it is not connected to
|
||||
// the updated ones (likely due to device filter settings). The worker
|
||||
// simply needs to update view ID and does not update other internal state.
|
||||
ctx->IncrementContextViewId();
|
||||
VLOG(1) << "Processing simplified UpdateContextRequest on "
|
||||
<< ctx->HostCPU()->name();
|
||||
return Status::OK();
|
||||
}
|
||||
// TODO(b/143914772): Potential memory leak if rendezvous has pending
|
||||
// tensors for removed / replaced workers.
|
||||
|
||||
|
72
tensorflow/core/protobuf/device_filters.proto
Normal file
72
tensorflow/core/protobuf/device_filters.proto
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "DeviceFiltersProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.distruntime";
|
||||
|
||||
// This file contains protos to be used when defining a TensorFlow
|
||||
// cluster.
|
||||
//
|
||||
// Configure device filters for remote tasks in the cluster. When associated
|
||||
// with a ClusterDef in setting up the cluster, a remote task will ignore all
|
||||
// devices which do not match any of its filters. Device filters must be
|
||||
// configured at the cluster startup, and cannot be updated once the cluster is
|
||||
// up and running.
|
||||
//
|
||||
// EXAMPLES
|
||||
// --------
|
||||
//
|
||||
// A two-job cluster with the following ClusterDef:
|
||||
//
|
||||
// Cluster:
|
||||
// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' }
|
||||
// tasks { key: 1 value: 'worker2:2222' } }
|
||||
// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
|
||||
// tasks { key: 1 value: 'ps1:2222' } }
|
||||
//
|
||||
// Set device filters to isolate worker tasks:
|
||||
//
|
||||
// ClusterDeviceFilters:
|
||||
// job { name: 'worker' tasks { key: 0
|
||||
// value: device_filter '/job:ps'
|
||||
// device_filter '/job:worker/task:0' }
|
||||
// tasks { key: 1
|
||||
// value: device_filter '/job:ps'
|
||||
// device_filter '/job:worker/task:1' } }
|
||||
|
||||
// Defines the device filters for a remote task.
|
||||
message TaskDeviceFilters {
|
||||
repeated string device_filters = 1;
|
||||
}
|
||||
|
||||
// Defines the device filters for tasks in a job.
|
||||
message JobDeviceFilters {
|
||||
// The name of this job.
|
||||
string name = 1;
|
||||
|
||||
// Mapping from task ID to task device filters.
|
||||
map<int32, TaskDeviceFilters> tasks = 2;
|
||||
}
|
||||
|
||||
// Defines the device filters for jobs in a cluster.
|
||||
message ClusterDeviceFilters {
|
||||
repeated JobDeviceFilters jobs = 1;
|
||||
}
|
@ -109,7 +109,10 @@ message UpdateContextRequest {
|
||||
// Identifies the full cluster, and this particular worker's position within.
|
||||
ServerDef server_def = 1;
|
||||
|
||||
// Device attributes in the cluster
|
||||
// Device attributes in the cluster.
|
||||
// If this field is empty, it indicates that this is a simple update request
|
||||
// that only increments the cluster view ID and does not require changes to
|
||||
// the workers it connects to.
|
||||
repeated DeviceAttributes cluster_device_attributes = 2;
|
||||
|
||||
// The ID of the context to be updated. A context with the specified ID must
|
||||
|
@ -15,14 +15,17 @@ limitations under the License.
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
import "tensorflow/core/protobuf/config.proto";
|
||||
import "tensorflow/core/protobuf/cluster.proto";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "tensorflow/core/protobuf/cluster.proto";
|
||||
import "tensorflow/core/protobuf/config.proto";
|
||||
import "tensorflow/core/protobuf/device_filters.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "ServerProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.distruntime";
|
||||
|
||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf";
|
||||
// Defines the configuration of a single TensorFlow server.
|
||||
message ServerDef {
|
||||
@ -51,4 +54,8 @@ message ServerDef {
|
||||
|
||||
// The server port. If not set, then we identify the port from the job_name.
|
||||
int32 port = 6;
|
||||
|
||||
// Device filters for remote tasks in the cluster.
|
||||
// NOTE: This is an experimental feature and only effective in TensorFlow 2.x.
|
||||
ClusterDeviceFilters cluster_device_filters = 7;
|
||||
}
|
||||
|
@ -83,7 +83,8 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
job_name="localhost",
|
||||
task_index=0,
|
||||
protocol=None,
|
||||
make_master_device_default=True):
|
||||
make_master_device_default=True,
|
||||
cluster_device_filters=None):
|
||||
"""Connects to the given cluster.
|
||||
|
||||
Will make devices on the cluster available to use. Note that calling this more
|
||||
@ -93,6 +94,30 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
If the given local job name is not present in the cluster specification, it
|
||||
will be automatically added, using an unused port on the localhost.
|
||||
|
||||
Device filters can be specified to isolate groups of remote tasks to avoid
|
||||
undesired accesses between workers. Workers accessing resources or launching
|
||||
ops / functions on filtered remote devices will result in errors (unknown
|
||||
devices). For any remote task, if no device filter is present, all cluster
|
||||
devices will be visible; if any device filter is specified, it can only
|
||||
see devices matching at least one filter. Devices on the task itself are
|
||||
always visible. Device filters can be particially specified.
|
||||
|
||||
For example, for a cluster set up for parameter server training, the following
|
||||
device filters might be specified:
|
||||
|
||||
```python
|
||||
cdf = tf.config.experimental.ClusterDeviceFilters()
|
||||
# For any worker, only the devices on PS nodes and itself are visible
|
||||
for i in range(num_workers):
|
||||
cdf.set_device_filters('worker', i, ['/job:ps'])
|
||||
# Similarly for any ps, only the devices on workers and itself are visible
|
||||
for i in range(num_ps):
|
||||
cdf.set_device_filters('ps', i, ['/job:worker'])
|
||||
|
||||
tf.config.experimental_connect_to_cluster(cluster_def,
|
||||
cluster_device_filters=cdf)
|
||||
```
|
||||
|
||||
Args:
|
||||
cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
|
||||
the cluster.
|
||||
@ -105,6 +130,9 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
master becomes the default device to run ops. It won't do anything if
|
||||
a cluster spec is passed. Will throw an error if the caller is currently
|
||||
already in some device scope.
|
||||
cluster_device_filters: an instance of
|
||||
`tf.train.experimental/ClusterDeviceFilters` that specify device filters
|
||||
to the remote tasks in cluster.
|
||||
"""
|
||||
if not context.executing_eagerly():
|
||||
raise ValueError(
|
||||
@ -125,6 +153,13 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
"`ClusterResolver`.")
|
||||
|
||||
cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
|
||||
if cluster_device_filters:
|
||||
if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
|
||||
cluster_device_filters = copy.deepcopy(
|
||||
cluster_device_filters._as_cluster_device_filters()) # pylint: disable=protected-access
|
||||
else:
|
||||
raise ValueError("`cluster_device_filters` must be an instance of "
|
||||
"`tf.train.experimental.ClusterDeviceFilters`.")
|
||||
|
||||
# Automatically add local job, if not part of the cluster spec.
|
||||
if job_name not in cluster_spec.jobs:
|
||||
@ -140,7 +175,8 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
job_name=job_name,
|
||||
task_index=task_index,
|
||||
protocol=protocol,
|
||||
default_session_config=context.context().config)
|
||||
default_session_config=context.context().config,
|
||||
cluster_device_filters=cluster_device_filters)
|
||||
|
||||
if context.get_server_def() is None:
|
||||
context.set_server_def(server_def)
|
||||
|
@ -290,13 +290,10 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
|
||||
def setUp(self):
|
||||
super(MultiJobsTest, self).setUp()
|
||||
|
||||
workers, ps = test_util.create_local_cluster(2, 1)
|
||||
workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2)
|
||||
cluster = {
|
||||
'my_worker': [
|
||||
_strip_prefix(workers[0].target, _GRPC_PREFIX),
|
||||
_strip_prefix(workers[1].target, _GRPC_PREFIX),
|
||||
],
|
||||
'my_ps': [_strip_prefix(ps[0].target, _GRPC_PREFIX)],
|
||||
'my_worker': [_strip_prefix(t.target, _GRPC_PREFIX) for t in workers],
|
||||
'my_ps': [_strip_prefix(t.target, _GRPC_PREFIX) for t in ps],
|
||||
}
|
||||
self._cluster = server_lib.ClusterSpec(cluster)
|
||||
self._cluster_resolver = SimpleClusterResolver(
|
||||
@ -330,6 +327,53 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
|
||||
with ops.device('/job:my_worker/task:1/device:CPU:0'):
|
||||
self.assertAllEqual(worker_fn(), 8)
|
||||
|
||||
@test_util.eager_lazy_remote_copy_on_and_off
|
||||
def testSimpleParameterServerWithDeviceFilters(self):
|
||||
cluster_device_filters = server_lib.ClusterDeviceFilters()
|
||||
for i in range(2):
|
||||
cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps'])
|
||||
cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker'])
|
||||
remote.connect_to_cluster(
|
||||
self._cluster, cluster_device_filters=cluster_device_filters)
|
||||
|
||||
with ops.device('/job:my_ps/task:0/device:CPU:0'):
|
||||
v1 = variables.Variable(initial_value=0)
|
||||
with ops.device('/job:my_ps/task:1/device:CPU:0'):
|
||||
v2 = variables.Variable(initial_value=10)
|
||||
|
||||
@def_function.function
|
||||
def worker_fn():
|
||||
v1.assign_add(1)
|
||||
v2.assign_sub(2)
|
||||
return v1.read_value() + v2.read_value()
|
||||
|
||||
with ops.device('/job:my_worker/task:0/device:CPU:0'):
|
||||
self.assertAllEqual(worker_fn(), 9)
|
||||
with ops.device('/job:my_worker/task:1/device:CPU:0'):
|
||||
self.assertAllEqual(worker_fn(), 8)
|
||||
|
||||
# The following remote call would fail because the ps nodes cannot see each
|
||||
# other due to the device filters.
|
||||
with self.assertRaises(errors.InvalidArgumentError) as cm:
|
||||
with ops.device('/job:my_ps/task:0/device:CPU:0'):
|
||||
worker_fn().numpy()
|
||||
self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
|
||||
cm.exception.message)
|
||||
|
||||
with self.assertRaises(errors.InvalidArgumentError) as cm:
|
||||
with ops.device('/job:my_ps/task:1/device:CPU:0'):
|
||||
worker_fn().numpy()
|
||||
self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
|
||||
cm.exception.message)
|
||||
|
||||
with ops.device('/job:my_worker/task:0/device:CPU:0'):
|
||||
self.assertAllEqual(worker_fn(), 7)
|
||||
with ops.device('/job:my_worker/task:1/device:CPU:0'):
|
||||
self.assertAllEqual(worker_fn(), 6)
|
||||
# Explicitly delete variables to avoid triggering errors when being GC'ed in
|
||||
# subsequent tests.
|
||||
del v1, v2
|
||||
|
||||
@test_util.eager_lazy_remote_copy_on_and_off
|
||||
def testConnectWithClusterResolver(self):
|
||||
remote.connect_to_cluster(self._cluster_resolver)
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.core.protobuf import device_filters_pb2
|
||||
from tensorflow.core.protobuf import tensorflow_server_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.framework import errors
|
||||
@ -488,3 +489,85 @@ class ClusterSpec(object):
|
||||
raise TypeError("Task address %r must be bytes or unicode" %
|
||||
task_address)
|
||||
job_def.tasks[i] = task_address
|
||||
|
||||
|
||||
@tf_export("config.experimental.ClusterDeviceFilters")
|
||||
class ClusterDeviceFilters(object):
|
||||
"""Represent a collection of device filters for the remote workers in cluster.
|
||||
|
||||
NOTE: this is an experimental API and subject to changes.
|
||||
|
||||
Set device filters for selective jobs and tasks. For each remote worker, the
|
||||
device filters are a list of strings. When any filters are present, the remote
|
||||
worker will ignore all devices which do not match any of its filters. Each
|
||||
filter can be partially specified, e.g. "/job:ps", "/job:worker/replica:3",
|
||||
etc. Note that a device is always visible to the worker it is located on.
|
||||
|
||||
For example, to set the device filters for a parameter server cluster:
|
||||
|
||||
```python
|
||||
cdf = tf.config.experimental.ClusterDeviceFilters()
|
||||
for i in range(num_workers):
|
||||
cdf.set_device_filters('worker', i, ['/job:ps'])
|
||||
for i in range(num_ps):
|
||||
cdf.set_device_filters('ps', i, ['/job:worker'])
|
||||
|
||||
tf.config.experimental_connect_to_cluster(cluster_def,
|
||||
cluster_device_filters=cdf)
|
||||
```
|
||||
|
||||
The device filters can be partically specified. For remote tasks that do not
|
||||
have device filters specified, all devices will be visible to them.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# `_device_filters` is a dict mapping job names to job device filters.
|
||||
# Job device filters further maps task IDs to task device filters.
|
||||
# Task device filters are a list of strings, each one is a device filter.
|
||||
self._device_filters = {}
|
||||
|
||||
# Serialized protobuf for cluster device filters.
|
||||
self._cluster_device_filters = None
|
||||
|
||||
def set_device_filters(self, job_name, task_index, device_filters):
|
||||
"""Set the device filters for given job name and task id."""
|
||||
assert all(isinstance(df, str) for df in device_filters)
|
||||
self._device_filters.setdefault(job_name, {})
|
||||
self._device_filters[job_name][task_index] = [df for df in device_filters]
|
||||
# Due to updates in data, invalidate the serialized proto cache.
|
||||
self._cluster_device_filters = None
|
||||
|
||||
def _as_cluster_device_filters(self):
|
||||
"""Returns a serialized protobuf of cluster device filters."""
|
||||
if self._cluster_device_filters:
|
||||
return self._cluster_device_filters
|
||||
|
||||
self._make_cluster_device_filters()
|
||||
return self._cluster_device_filters
|
||||
|
||||
def _make_cluster_device_filters(self):
|
||||
"""Creates `ClusterDeviceFilters` proto based on the `_device_filters`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `_device_filters` is not a dictionary mapping strings to
|
||||
a map of task indices and device filters.
|
||||
"""
|
||||
self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters()
|
||||
|
||||
# Sort by job_name to produce deterministic protobufs.
|
||||
for job_name, tasks in sorted(self._device_filters.items()):
|
||||
try:
|
||||
job_name = compat.as_bytes(job_name)
|
||||
except TypeError:
|
||||
raise TypeError("Job name %r must be bytes or unicode" % job_name)
|
||||
|
||||
jdf = self._cluster_device_filters.jobs.add()
|
||||
jdf.name = job_name
|
||||
|
||||
for i, task_device_filters in sorted(tasks.items()):
|
||||
for tdf in task_device_filters:
|
||||
try:
|
||||
tdf = compat.as_bytes(tdf)
|
||||
except TypeError:
|
||||
raise TypeError("Device filter %r must be bytes or unicode" % tdf)
|
||||
jdf.tasks[i].device_filters.append(tdf)
|
||||
|
@ -0,0 +1,13 @@
|
||||
path: "tensorflow.config.experimental.ClusterDeviceFilters"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.server_lib.ClusterDeviceFilters\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_device_filters"
|
||||
argspec: "args=[\'self\', \'job_name\', \'task_index\', \'device_filters\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.config.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "ClusterDeviceFilters"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "VirtualDeviceConfiguration"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -26,7 +26,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_cluster"
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\', \'cluster_device_filters\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_host"
|
||||
|
@ -40,5 +40,12 @@ tf_proto {
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_INT32
|
||||
}
|
||||
field {
|
||||
name: "cluster_device_filters"
|
||||
number: 7
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.ClusterDeviceFilters"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,13 @@
|
||||
path: "tensorflow.config.experimental.ClusterDeviceFilters"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.server_lib.ClusterDeviceFilters\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_device_filters"
|
||||
argspec: "args=[\'self\', \'job_name\', \'task_index\', \'device_filters\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.config.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "ClusterDeviceFilters"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "VirtualDeviceConfiguration"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -26,7 +26,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_cluster"
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\', \'cluster_device_filters\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_host"
|
||||
|
@ -40,5 +40,12 @@ tf_proto {
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_INT32
|
||||
}
|
||||
field {
|
||||
name: "cluster_device_filters"
|
||||
number: 7
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.ClusterDeviceFilters"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user