From 9da4a05369fd4fcdbe347f3163e5ed5878f5c454 Mon Sep 17 00:00:00 2001
From: Ran Chen
Date: Wed, 2 Sep 2020 13:16:06 -0700
Subject: [PATCH] Exchange device attributes at group resolution again

Previously, CollectiveParamResolver queried device attributes when
initializing instance params. That is problematic if the collective leader
fails and restarts quickly between group resolution and instance resolution:
all other workers get the incarnation of the restarted leader, so they are
unable to detect that the leader has failed, and the leader deadlocks on
group resolution.

This change doesn't fully fix the issue, because it only exchanges device
attributes at group resolution; it doesn't yet populate the device
attributes into DeviceResolver. That will be done in a follow-up change.

This change also changes the behavior when a non-leader fails and restarts.
Previously it got the cached group resolution from the leader; now it gets
an error, because its incarnation no longer matches the one in the cached
group parameters. This should have no practical effect, since that worker
will always restart again after the leader has restarted.

This change modifies both the client and the server without being backward
compatible. It assumes that the client and the server run the same version
of TensorFlow. This should hold in practice, since the only way to use
CollectiveParamResolverDistributed is through MultiWorkerMirroredStrategy
(MWMS), and with MWMS all workers should run the same version of the
program.

PiperOrigin-RevId: 329774332
Change-Id: I6f3ba535cefd3a8ec321e4138436b6a085d64463
---
 tensorflow/core/common_runtime/BUILD          |   1 -
 .../base_collective_executor.cc               |   4 +-
 .../common_runtime/base_collective_executor.h |   2 +-
 .../collective_param_resolver_local.cc        |  96 +++++---
 .../collective_param_resolver_local.h         |  11 +-
 .../collective_param_resolver_local_test.cc   |  76 +++---
 .../test_collective_executor_mgr.h            |   3 +-
 tensorflow/core/distributed_runtime/BUILD     |   3 +
 .../collective_param_resolver_distributed.cc  | 152 ++++++------
 .../collective_param_resolver_distributed.h   |  12 +-
 ...lective_param_resolver_distributed_test.cc | 233 ++++++++++--------
 tensorflow/core/framework/collective.h        |   6 +-
 tensorflow/core/kernels/collective_ops.cc     |   5 +-
 tensorflow/core/protobuf/worker.proto         |   9 +-
 14 files changed, 340 insertions(+), 273 deletions(-)

diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index 73c1458eab4..313866c7d5d 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -379,7 +379,6 @@ cc_library(
     hdrs = ["collective_param_resolver_local.h"],
     copts = tf_copts(),
     deps = [
-        ":device",
         ":device_mgr",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index a6629286698..6e5e5c82c42 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -298,8 +298,8 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
 }
 
 void BaseCollectiveExecutor::CompleteParamsAsync(
-    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
-    StatusCallback done) {
+    const DeviceAttributes& device, CollectiveParams* cp,
+    CancellationManager* cancel_mgr, StatusCallback done) {
   cp->instance.gpu_ring_order = *gpu_ring_order_;
   const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
   auto done_with_timeout = done;
diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h
index c9cea393378..4081b887add 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.h
+++ b/tensorflow/core/common_runtime/base_collective_executor.h
@@ -113,7 +113,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
   void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
                     const string& exec_key, StatusCallback done) override;
 
-  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
                            CancellationManager* cancel_mgr,
                            StatusCallback done) override;
 
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index ba21abcbaa8..cae7dd56621 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include <stddef.h>
 #include <algorithm>
-#include <unordered_map>
+#include <unordered_set>
 #include <utility>
 
 #include "tensorflow/core/common_runtime/device_mgr.h"
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/config.pb.h"
@@ -74,12 +75,21 @@ const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
       return "undef";
   }
 }
+
+string TaskNameFromDeviceName(const string& device_name) {
+  DeviceNameUtils::ParsedName parsed_device;
+  CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
+  string task_name;
+  CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
+  return task_name;
+}
 }  // namespace
 
 void CollectiveParamResolverLocal::CompleteGroupLocal(
-    const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
-  VLOG(1) << "CompleteGroupLocal device=" << device << " cp: " << cp << ": "
-          << cp->ToString();
+    const DeviceAttributes& device, CollectiveParams* cp,
+    const GroupRecCallback& done) {
+  VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp
+          << ": " << cp->ToString();
   std::vector<StatusCallback> to_be_called;
   GroupRec* gr = nullptr;
   Status status;
@@ -139,13 +149,13 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
       //   status.
       VLOG(2) << "gr device_type=" << gr->group.device_type
               << " cp device_type=" << cp->group.device_type
-              << " current device=" << device;
+              << " current device=" << device.name();
       if (gr->status.ok()) {
         // Check for consistency with existing GroupRec.
         if (cp->group.device_type != gr->group.device_type) {
           gr->status = errors::Internal(
-              "Collective Op ", cp->name, " is assigned to device ", device,
-              " with type ", cp->group.device_type.type_string(),
+              "Collective Op ", cp->name, " is assigned to device ",
+              device.name(), " with type ", cp->group.device_type.type_string(),
               " and group_key ", cp->group.group_key,
               " but that group has type ", gr->group.device_type.type_string());
         } else if (cp->group.group_size != gr->group.group_size) {
@@ -157,38 +167,47 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
         }
         if (gr->status.ok()) {
           // Insert device if not already present.
-          auto it = gr->device_set.find(device);
-          if (it == gr->device_set.end()) {
-            if (gr->device_set.size() == gr->group.group_size) {
+          auto it = gr->devices.find(device.name());
+          if (it == gr->devices.end()) {
+            if (gr->devices.size() == gr->group.group_size) {
               // The group is already full.
               gr->status = errors::Internal(
-                  "Collective Op ", cp->name, " is assigned to device ", device,
-                  " and group_key ", cp->group.group_key,
+                  "Collective Op ", cp->name, " is assigned to device ",
+                  device.name(), " and group_key ", cp->group.group_key,
                   " but that group doesn't contain that device.");
             } else {
               // This is a new device that has not yet joined the group.
-              gr->device_set.insert(device);
-              gr->device_list.push_back(device);
-              DeviceNameUtils::ParsedName parsed_device;
-              DeviceNameUtils::ParseFullName(device, &parsed_device);
-              string task_name = strings::StrCat("/job:", parsed_device.job,
-                                                 "/replica:", parsed_device.replica,
-                                                 "/task:", parsed_device.task);
-              gr->task_set.insert(task_name);
-              gr->task_list.push_back(task_name);
-              gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
+              gr->devices[device.name()] = device;
+              if (gr->devices.size() == gr->group.group_size) {
+                // The group is full after adding this device; calculate the
+                // number of tasks.
+                std::unordered_set<string> tasks;
+                for (const auto& item : gr->devices) {
+                  tasks.insert(TaskNameFromDeviceName(item.first));
+                }
+                gr->group.num_tasks = static_cast<int32>(tasks.size());
+              }
               if (VLOG_IS_ON(1)) {
                 string dev_buf;
-                for (const auto& d : gr->device_set) {
-                  strings::StrAppend(&dev_buf, ",", d);
+                for (const auto& d : gr->devices) {
+                  strings::StrAppend(&dev_buf, ",", d.first);
                 }
                 VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
                         << " group_size=" << gr->group.group_size << " (current"
                         << " devices)=(" << dev_buf << ") (number of"
                         << " devices pending)="
-                        << (gr->group.group_size - gr->device_set.size());
+                        << (gr->group.group_size - gr->devices.size());
               }
             }
+          } else {
+            // If the device already exists, check if the incarnation matches.
+            if (it->second.incarnation() != device.incarnation()) {
+              gr->status = errors::FailedPrecondition(
+                  "Device ", device.name(),
+                  " current incarnation doesn't match with one in the group. "
+                  "This usually means this worker has restarted but the "
+                  "collective leader hasn't, or this worker connects to a "
+                  "wrong cluster.");
+            }
           }
         }
 
@@ -196,13 +215,13 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
       cp->group.runtime_details = gr->group.runtime_details;
       // If the group is not yet complete, queue to wait for it.
       VLOG(2) << "group_size " << gr->group.group_size << " set size "
-              << gr->device_set.size() << " gr " << gr;
+              << gr->devices.size() << " gr " << gr;
 
-      if (gr->device_set.size() < gr->group.group_size) {
+      if (gr->devices.size() < gr->group.group_size) {
         gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
         return;
       }
-      CHECK_EQ(gr->device_set.size(), gr->group.group_size);
+      CHECK_EQ(gr->devices.size(), gr->group.group_size);
   }
   // At this point, we either have a full group, or an error status.  Ensure
   // that all callbacks are invoked with the appropriate status.
@@ -481,10 +500,15 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
   {
     mutex_lock gl(gr->mu);
     ir->shared.group = gr->group;
-    ir->shared.instance.device_names.assign(gr->device_list.begin(),
-                                            gr->device_list.end());
-    ir->shared.instance.task_names.assign(gr->task_list.begin(),
-                                          gr->task_list.end());
+    ir->shared.instance.device_names.clear();
+    ir->shared.instance.task_names.clear();
+    ir->shared.instance.device_names.reserve(gr->devices.size());
+    ir->shared.instance.task_names.reserve(gr->devices.size());
+    for (const auto& item : gr->devices) {
+      ir->shared.instance.device_names.push_back(item.first);
+      ir->shared.instance.task_names.push_back(
+          TaskNameFromDeviceName(item.first));
+    }
     VLOG(2) << "Initialized names for instance: "
             << ir->shared.instance.ToString();
   }
@@ -682,15 +706,15 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
 }
 
 void CollectiveParamResolverLocal::CompleteParamsAsync(
-    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
-    const StatusCallback& done) {
-  VLOG(1) << "CompleteParams local " << device << " for " << cp << ": "
+    const DeviceAttributes& device, CollectiveParams* cp,
+    CancellationManager* cancel_mgr, const StatusCallback& done) {
+  VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
          << cp->ToString();
   CompleteGroupLocal(
       device, cp, [this, device, cp, done](const Status& s, const GroupRec* gr) {
        if (s.ok()) {
-          CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
+          CompleteInstanceLocal(device.name(), gr, cp, cp->is_source, done);
        } else {
          done(s);
        }
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 40f0f00affc..51cfa893d00 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -19,9 +19,11 @@ limitations under the License.
 #include <memory>
 #include <set>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 
@@ -45,7 +47,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
 
   ~CollectiveParamResolverLocal() override {}
 
-  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
                            CancellationManager* cancel_mgr,
                            const StatusCallback& done) override;
 
@@ -70,10 +72,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
     CollGroupParams group;
     mutable mutex mu;
     Status status TF_GUARDED_BY(mu);
-    std::set<string> device_set TF_GUARDED_BY(mu);
-    std::vector<string> device_list TF_GUARDED_BY(mu);
-    std::set<string> task_set TF_GUARDED_BY(mu);
-    std::vector<string> task_list TF_GUARDED_BY(mu);
+    std::unordered_map<string, DeviceAttributes> devices TF_GUARDED_BY(mu);
     std::vector<StatusCallback> waiting TF_GUARDED_BY(mu);
   };
 
@@ -85,7 +84,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
   // callback.
  typedef std::function<void(const Status& s, const GroupRec* gr)>
      GroupRecCallback;
-  void CompleteGroupLocal(const string& device, CollectiveParams* cp,
+  void CompleteGroupLocal(const DeviceAttributes& device, CollectiveParams* cp,
                           const GroupRecCallback& done)
       TF_LOCKS_EXCLUDED(group_mu_);
 
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index f23f03dc406..b117632dbd2 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_resolver_local.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -86,6 +87,12 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
     }
   }
 
+  DeviceAttributes GetDeviceAttributes(const string& device_name) {
+    Device* device = nullptr;
+    TF_CHECK_OK(device_mgr_->LookupDevice(device_name, &device));
+    return device->attributes();
+  }
+
   string task_name_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<DeviceResolverLocal> drl_;
@@ -187,12 +194,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
     cp->instance.impl_details.subdiv_offsets.push_back(0);
     cp->is_source = false;
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
-      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
-                                nullptr /*CancellationManager*/,
-                                [&statuses, &note, i](const Status& s) {
-                                  statuses[i] = s;
-                                  note[i].Notify();
-                                });
+      prl_->CompleteParamsAsync(
+          GetDeviceAttributes(cp->instance.device_names[0]), cp,
+          nullptr /*CancellationManager*/,
+          [&statuses, &note, i](const Status& s) {
+            statuses[i] = s;
+            note[i].Notify();
+          });
     });
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
@@ -240,12 +248,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
     CollectiveParams* cp = &cps[i];
     InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
-      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
-                                nullptr /*CancellationManager*/,
-                                [&statuses, &note, i](const Status& s) {
-                                  statuses[i] = s;
-                                  note[i].Notify();
-                                });
+      prl_->CompleteParamsAsync(
+          GetDeviceAttributes(cp->instance.device_names[0]), cp,
+          nullptr /*CancellationManager*/,
+          [&statuses, &note, i](const Status& s) {
+            statuses[i] = s;
+            note[i].Notify();
+          });
     });
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
@@ -278,12 +287,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
     CollectiveParams* cp = &cps[i];
     InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
-      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
-                                nullptr /*CancellationManager*/,
-                                [&statuses, &note, i](const Status& s) {
-                                  statuses[i] = s;
-                                  note[i].Notify();
-                                });
+      prl_->CompleteParamsAsync(
+          GetDeviceAttributes(cp->instance.device_names[0]), cp,
+          nullptr /*CancellationManager*/,
+          [&statuses, &note, i](const Status& s) {
+            statuses[i] = s;
+            note[i].Notify();
+          });
     });
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
@@ -326,8 +336,8 @@
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr, - [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], + &cancel_mgr, [&done](const Status& s) { EXPECT_EQ(s.code(), error::ABORTED); EXPECT_EQ(s.error_message(), "__aborted__"); done.DecrementCount(); @@ -355,8 +365,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(group_key, instance_key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr, - [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], + &cancel_mgr, [&done](const Status& s) { EXPECT_EQ(s.code(), error::OK); done.DecrementCount(); }); @@ -373,12 +383,13 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(group_key, instance_key + 1, /*is_source*/ i == 0); - prl_->CompleteParamsAsync( - device, &cp[i], &cancel_mgr, [&done](const Status& s) { - EXPECT_EQ(s.code(), error::ABORTED); - EXPECT_EQ(s.error_message(), "__aborted__"); - done.DecrementCount(); - }); + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], + &cancel_mgr, [&done](const Status& s) { + EXPECT_EQ(s.code(), error::ABORTED); + EXPECT_EQ(s.error_message(), + "__aborted__"); + done.DecrementCount(); + }); start.DecrementCount(); }); } @@ -402,8 +413,8 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(group_key, instance_key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr, - [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], + &cancel_mgr, [&done](const Status& s) { EXPECT_EQ(s.code(), error::OK); done.DecrementCount(); }); @@ -418,7 +429,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { Notification done; auto cp = MakeCollectiveParams(group_key, instance_key, /*is_source*/ true); - prl_->CompleteParamsAsync(device, &cp, &cancel_mgr, + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, &cancel_mgr, [&done](const Status& s) { EXPECT_EQ(s.code(), error::ABORTED); EXPECT_EQ(s.error_message(), "__aborted__"); @@ -457,7 +468,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) { auto cp = MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(device, &cp, &cancel_mgr, + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, + &cancel_mgr, [&status, &n](const Status& s) { status = s; n.Notify(); diff --git a/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/tensorflow/core/common_runtime/test_collective_executor_mgr.h index c2e6d2ae08c..ff4d966d76c 100644 --- a/tensorflow/core/common_runtime/test_collective_executor_mgr.h +++ b/tensorflow/core/common_runtime/test_collective_executor_mgr.h @@ -16,6 +16,7 @@ limitations under the License. 
#define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_ #include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { @@ -35,7 +36,7 @@ class TestCollectiveExecutor : public CollectiveExecutor { }; class TestParamResolver : public ParamResolverInterface { - void CompleteParamsAsync(const string& device, CollectiveParams* cp, + void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, CancellationManager* cancel_mgr, const StatusCallback& done) override { done(errors::Internal("Unimplemented")); diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 94570c1b577..6e87d28781f 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -571,7 +571,9 @@ cc_library( ":device_resolver_distributed", ":worker_cache", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", "@com_google_absl//absl/strings", ], ) @@ -606,6 +608,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/kernels:collective_ops", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index 650c52cd8da..44d90818f9b 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -18,14 +18,18 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/cancellable_call.h" #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { namespace { class CompleteGroupCall : public CancellableCall { public: - CompleteGroupCall(const CollGroupParams& group, const string& device_name, + CompleteGroupCall(const CollGroupParams& group, + const DeviceAttributes& device, const CollectiveType& collective_type, CancellationManager* cancel_mgr, const string& remote_worker, WorkerCacheInterface* wc) @@ -33,7 +37,7 @@ class CompleteGroupCall : public CancellableCall { req_.set_group_key(group.group_key); req_.set_group_size(group.group_size); req_.set_device_type(group.device_type.type_string()); - req_.add_device_name(device_name); + *req_.mutable_device_attributes() = device; req_.set_collective_type(collective_type); } ~CompleteGroupCall() override {} @@ -98,16 +102,16 @@ CollectiveParamResolverDistributed::CollectiveParamResolverDistributed( } void CollectiveParamResolverDistributed::CompleteParamsAsync( - const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr, - const StatusCallback& done) { - VLOG(1) << "CompleteParams distributed " << device << " for " << cp << ": " - << cp->ToString(); + const DeviceAttributes& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, const StatusCallback& done) { + VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp + << ": " << cp->ToString(); CompleteGroupDistributed(device, 
cp, cancel_mgr, [this, device, cp, cancel_mgr, done]( const Status& s, const GroupRec* gr) { if (s.ok()) { - CompleteInstanceDistributed(device, gr, cp, - cancel_mgr, done); + CompleteInstanceDistributed( + device.name(), gr, cp, cancel_mgr, done); } else { done(s); } @@ -117,28 +121,28 @@ void CollectiveParamResolverDistributed::CompleteParamsAsync( void CollectiveParamResolverDistributed::CompleteGroupAsync( const CompleteGroupRequest* request, CompleteGroupResponse* response, CancellationManager* cancel_mgr, const StatusCallback& done) { + if (!request->has_device_attributes()) { + done(errors::Internal( + "CompleteGroupRequest device_attributes is not set. Make sure you're " + "running the same version of Tensorflow on all workers.")); + return; + } CollectiveParams cp; cp.group.group_key = request->group_key(); cp.group.group_size = request->group_size(); cp.group.device_type = DeviceType(request->device_type()); - for (const string& dn : request->device_name()) { - cp.instance.device_names.push_back(dn); - } cp.instance.type = CollectiveType(request->collective_type()); CompleteGroupDistributed( - cp.instance.device_names[0], &cp, cancel_mgr, + request->device_attributes(), &cp, cancel_mgr, [response, done](const Status& s, const GroupRec* gr) { if (s.ok()) { mutex_lock l(gr->mu); response->set_group_key(gr->group.group_key); response->set_group_size(gr->group.group_size); response->set_device_type(gr->group.device_type.type_string()); - response->set_num_tasks(gr->task_set.size()); - for (const string& dn : gr->device_list) { - response->add_device_name(dn); - } - for (const string& tn : gr->task_list) { - response->add_task_name(tn); + response->set_num_tasks(gr->group.num_tasks); + for (const auto& item : gr->devices) { + *response->add_device_attributes() = item.second; } response->set_communicator_key( gr->group.runtime_details.communicator_key); @@ -152,6 +156,22 @@ void CollectiveParamResolverDistributed::CompleteGroupAsync( void CollectiveParamResolverDistributed::CompleteInstanceAsync( const CompleteInstanceRequest* request, CompleteInstanceResponse* response, CancellationManager* cancel_mgr, const StatusCallback& done) { + GroupRec* gr = GetCachedGroup(request->group_key()); + if (gr == nullptr) { + done(errors::FailedPrecondition( + "group ", request->group_key(), + " not found. This normally means the server has restarted")); + return; + } + { + mutex_lock l(gr->mu); + if (!gr->status.ok() || gr->devices.size() != gr->group.group_size) { + done(errors::FailedPrecondition( + "group ", request->group_key(), + " failed to resolve. This normally means the server has restarted")); + return; + } + } CollectiveParams* cp = new CollectiveParams; cp->name = request->name(); cp->group.group_key = request->group_key(); @@ -164,56 +184,44 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync( for (int32 offset : request->subdiv_offset()) { cp->instance.impl_details.subdiv_offsets.push_back(offset); } - string* device = new string(request->device()); - VLOG(1) << "New cp " << cp << " for device " << *device << " : " - << cp->ToString(); - StatusCallback done_and_cleanup = [cp, device, done](const Status& s) { + StatusCallback done_and_cleanup = [cp, done](const Status& s) { done(s); delete cp; - delete device; }; - // Start by completing the group. - CompleteGroupDistributed( - *device, cp, cancel_mgr, - [this, cp, device, response, cancel_mgr, done_and_cleanup]( - const Status& cg_status, const GroupRec* gr) { - if (cg_status.ok()) { - // Then complete the instance. 
- CompleteInstanceDistributed( - *device, gr, cp, cancel_mgr, - [this, gr, cp, response, - done_and_cleanup](const Status& ci_status) { - if (ci_status.ok()) { - // Now source_rank should be known, so - // retrieve it. - FindInstanceRec( - gr, cp, - [cp, response, done_and_cleanup](const Status& fi_status, - InstanceRec* ir) { - if (fi_status.ok()) { - mutex_lock l(ir->out_mu); - ir->WaitForOutMu(l); - response->set_instance_key(cp->instance.instance_key); - response->set_source_rank(ir->source_rank); - done_and_cleanup(fi_status); - } else { - done_and_cleanup(fi_status); - } - }); + CompleteInstanceDistributed( + request->device(), gr, cp, cancel_mgr, + [this, gr, cp, response, done_and_cleanup](const Status& ci_status) { + if (ci_status.ok()) { + // Now source_rank should be known, so + // retrieve it. + FindInstanceRec( + gr, cp, + [cp, response, done_and_cleanup](const Status& fi_status, + InstanceRec* ir) { + if (fi_status.ok()) { + mutex_lock l(ir->out_mu); + ir->WaitForOutMu(l); + response->set_instance_key(cp->instance.instance_key); + response->set_source_rank(ir->source_rank); + done_and_cleanup(fi_status); } else { - done_and_cleanup(ci_status); + done_and_cleanup(fi_status); } }); } else { - done_and_cleanup(cg_status); + done_and_cleanup(ci_status); } }); } -bool CollectiveParamResolverDistributed::GroupIsCached(int32 group_key) { +CollectiveParamResolverDistributed::GroupRec* +CollectiveParamResolverDistributed::GetCachedGroup(int32 group_key) { mutex_lock l(group_mu_); - const auto& it = group_table_.find(group_key); - return it != group_table_.end(); + auto it = group_table_.find(group_key); + if (it == group_table_.end()) { + return nullptr; + } + return it->second.get(); } Status CollectiveParamResolverDistributed::UpdateGroupCache( @@ -226,26 +234,19 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache( gr->group.group_key = resp.group_key(); gr->group.group_size = resp.group_size(); gr->group.num_tasks = resp.num_tasks(); - if (resp.device_name_size() != gr->group.group_size) { + if (resp.device_attributes().empty()) { + return errors::Internal( + "CompleteGroupResponse device_attributes is empty. Make sure you're " + "running the same version of Tensorflow on all workers."); + } + if (resp.device_attributes_size() != gr->group.group_size) { return errors::Internal( "CompleteGroupResponse group_size doesn't match device_name list"); } - for (const string& dn : resp.device_name()) { - gr->device_set.insert(dn); - gr->device_list.push_back(dn); + for (const DeviceAttributes& device : resp.device_attributes()) { + gr->devices[device.name()] = device; } - if (resp.task_name_size() != gr->group.group_size) { - return errors::Internal( - "CompleteGroupResponse group_size doesn't match task_name list"); - } - for (const string& tn : resp.task_name()) { - gr->task_list.push_back(tn); - gr->task_set.insert(tn); - } - CHECK_EQ(gr->task_set.size(), gr->group.num_tasks); gr->group.runtime_details.communicator_key = resp.communicator_key(); - VLOG(2) << "Group communicator_key=" - << absl::CEscape(gr->group.runtime_details.communicator_key); } { // Group membership should never change. 
Once a record is in group_table_ @@ -273,14 +274,15 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache( } void CollectiveParamResolverDistributed::CompleteGroupDistributed( - const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr, - const GroupRecCallback& done) { + const DeviceAttributes& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, const GroupRecCallback& done) { VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key - << " dev: " << device << " is_leader=" << (group_leader_.empty()); + << " dev: " << device.name() + << " is_leader=" << (group_leader_.empty()); if (group_leader_.empty()) { // This is the group leader, so resolution is local. return CompleteGroupLocal(device, cp, done); - } else if (!GroupIsCached(cp->group.group_key)) { + } else if (GetCachedGroup(cp->group.group_key) == nullptr) { // Need to update Group cache from the leader. CompleteGroupCall* call = new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr, diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h index 684887430c3..fc692a12fc6 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_ #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" +#include "tensorflow/core/framework/device_attributes.pb.h" namespace tensorflow { class ConfigProto; @@ -31,7 +32,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { WorkerCacheInterface* worker_cache, const string& task_name); - void CompleteParamsAsync(const string& device, CollectiveParams* cp, + void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, CancellationManager* cancel_mgr, const StatusCallback& done) override; @@ -46,9 +47,9 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { const StatusCallback& done) override; protected: - // Returns true iff there's an entry for this group_key in the - // local group_table_. - bool GroupIsCached(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_); + // Returns the cached group iff there's an entry for this group_key in the + // local group_table_; returns nullptr otherwise. + GroupRec* GetCachedGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_); // Updates group_table_ with contents of resp. Status UpdateGroupCache(const CompleteGroupResponse& resp) @@ -59,7 +60,8 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { // // Semantics are like those of CompleteGroupLocal but will make a // remote call to the group leader if necessary. 
-  void CompleteGroupDistributed(const string& device, CollectiveParams* cp,
+  void CompleteGroupDistributed(const DeviceAttributes& device,
+                                CollectiveParams* cp,
                                 CancellationManager* cancel_mgr,
                                 const GroupRecCallback& done);
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index 130a48e80d2..a963c02d413 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
 
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 #include "tensorflow/core/distributed_runtime/test_utils.h"
@@ -23,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/random.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/util/device_name_utils.h"
 
@@ -41,6 +43,7 @@ static std::unique_ptr<Device> NewDevice(const string& type,
   attr.set_name(name);
   attr.set_device_type(type);
   attr.mutable_locality()->set_numa_node(3);  // a non-default value
+  attr.set_incarnation(random::New64());
   return absl::make_unique<FakeDevice>(attr);
 }
 
@@ -125,127 +128,110 @@ class FakeCache : public TestWorkerCache {
 
 class DeviceResDistTest : public ::testing::Test {
  protected:
-  DeviceResDistTest() {}
-
-  ~DeviceResDistTest() override {
-    for (DeviceMgr* dm : device_mgrs_) {
-      delete dm;
-    }
-    for (auto it : dev_resolvers_) {
-      delete it.second;
-    }
-    for (auto it : cp_resolvers_) {
-      delete it.second;
-    }
-    for (FakeWorker* w : workers_) {
-      delete w;
-    }
-  }
-
   void DefineWorkers(int num_workers, int num_devices,
                      const string& device_type, bool nccl) {
-    ConfigProto config;
     for (int w = 0; w < num_workers; ++w) {
       string name = strings::StrCat("/job:worker/replica:0/task:", w);
-      if (w == 0) {
-        config.mutable_experimental()->set_collective_group_leader(name);
-        if (nccl) {
-          config.mutable_experimental()->set_collective_nccl(true);
-        }
-      }
-      DefineWorker(config, name, device_type, num_devices);
+      DefineWorker(name, device_type, num_devices, nccl);
     }
   }
 
-  void DefineWorker(const ConfigProto& config, const string& worker_name,
-                    const string& device_type, int num_devices) {
+  void DefineWorker(const string& worker_name, const string& device_type,
+                    int num_devices, bool nccl) {
+    ConfigProto config;
+    config.mutable_experimental()->set_collective_group_leader(
+        "/job:worker/replica:0/task:0");
+    config.mutable_experimental()->set_collective_nccl(nccl);
+
     std::vector<std::unique_ptr<Device>> devices;
     for (int i = 0; i < num_devices; ++i) {
       devices.push_back(NewDevice(
          device_type,
          strings::StrCat(worker_name, "/device:", device_type, ":", i)));
     }
-    DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
-    device_mgrs_.push_back(dev_mgr);
+    device_mgrs_[worker_name] =
+        absl::make_unique<StaticDeviceMgr>(std::move(devices));
     std::vector<string>* dv = &dev_by_task_[worker_name];
-    for (auto* d : dev_mgr->ListDevices()) {
+    dv->clear();
+    for (auto* d : device_mgrs_[worker_name]->ListDevices()) {
       dv->push_back(d->name());
     }
-    DeviceResolverDistributed* dev_res =
-        new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
-    dev_resolvers_[worker_name] = dev_res;
-    CollectiveParamResolverDistributed* cp_res =
-        new CollectiveParamResolverDistributed(config, dev_mgr, dev_res, &wc_,
-                                               worker_name);
-    cp_resolvers_[worker_name] = cp_res;
-    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, cp_res);
-    workers_.push_back(fw);
-    wc_.AddWorker(worker_name, fw);
+    dev_resolvers_[worker_name] = absl::make_unique<DeviceResolverDistributed>(
+        device_mgrs_[worker_name].get(), &wc_, worker_name);
+    cp_resolvers_[worker_name] =
+        absl::make_unique<CollectiveParamResolverDistributed>(
+            config, device_mgrs_[worker_name].get(),
+            dev_resolvers_[worker_name].get(), &wc_, worker_name);
+    workers_[worker_name] = absl::make_unique<FakeWorker>(
+        worker_name, device_mgrs_[worker_name].get(),
+        cp_resolvers_[worker_name].get());
+    wc_.AddWorker(worker_name, workers_[worker_name].get());
   }
 
-  void DefineCollectiveParams(int num_workers, int num_devices) {
-    const int kGroupKey = 5;
-    const int kInstanceKey = 3;
+  void DefineCollectiveParams(int num_workers, int num_devices,
+                              const string& device_type) {
     for (int wi = 0; wi < num_workers; ++wi) {
       string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
       for (int di = 0; di < num_devices; ++di) {
-        string device_name = strings::StrCat(task_name, "/device:CPU:", di);
-        cp_.push_back(CollectiveParams());
-        CollectiveParams& cp = cp_.back();
-        cp.group.group_key = kGroupKey;
-        cp.group.group_size = num_workers * num_devices;
-        cp.group.device_type = DEVICE_CPU;
-        cp.group.num_tasks = num_workers;
-        cp.instance.instance_key = kInstanceKey;
-        cp.instance.type = REDUCTION_COLLECTIVE;
-        cp.instance.data_type = DT_FLOAT;
-        cp.instance.shape = TensorShape({64});
-        cp.instance.impl_details.subdiv_offsets.push_back(0);
+        string device_name =
+            strings::StrCat(task_name, "/device:", device_type, ":", di);
+        cp_[device_name] =
+            CreateCollectiveParams(num_workers, num_devices, device_type);
       }
     }
   }
 
+  CollectiveParams CreateCollectiveParams(int num_workers, int num_devices,
+                                          const string& device_type) {
+    const int kGroupKey = 5;
+    const int kInstanceKey = 3;
+    CollectiveParams cp;
+    cp.group.group_key = kGroupKey;
+    cp.group.group_size = num_workers * num_devices;
+    cp.group.device_type = DeviceType(device_type);
+    cp.group.num_tasks = num_workers;
+    cp.instance.instance_key = kInstanceKey;
+    cp.instance.type = REDUCTION_COLLECTIVE;
+    cp.instance.data_type = DT_FLOAT;
+    cp.instance.shape = TensorShape({64});
+    cp.instance.impl_details.subdiv_offsets.push_back(0);
+    return cp;
+  }
+
   void IssueRequests(int num_workers, int num_devices) {
-    const int device_count = num_workers * num_devices;
     {
       mutex_lock l(mu_);
       num_done_ = 0;
     }
-    cp_.resize(device_count);
-    status_.resize(device_count);
-    int idx = 0;
+    int group_size = num_workers * num_devices;
     for (int wi = 0; wi < num_workers; ++wi) {
+      string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
      for (int di = 0; di < num_devices; ++di) {
-        IssueRequest(num_workers, num_devices, idx);
-        ++idx;
+        string device_name = strings::StrCat(task_name, "/device:CPU:", di);
+        IssueRequest(task_name, device_name, group_size);
       }
     }
   }
 
-  void IssueRequest(int num_workers, int num_devices, int idx) {
-    int device_count = num_workers * num_devices;
-    int wi = idx / num_devices;
-    int di = idx % num_devices;
-    string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
-    string device_name = strings::StrCat(task_name, "/device:CPU:", di);
-    while (idx >= cp_.size()) {
-      status_.resize(idx + 1);
-      cp_.resize(idx + 1);
-    }
-    CollectiveParams* cp = &cp_[idx];
-    CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name];
+  void IssueRequest(const string& task_name, const string& device_name,
+                    int group_size) {
+    Device* device = nullptr;
+    TF_CHECK_OK(device_mgrs_[task_name]->LookupDevice(device_name, &device));
+    CollectiveParams* cp = &cp_[device_name];
+    CollectiveParamResolverDistributed* cp_res =
+        cp_resolvers_[task_name].get();
     CHECK(cp_res);
-    cp_res->CompleteParamsAsync(device_name, cp, &cm_,
-                                [this, idx, device_count](const Status& s) {
-                                  status_[idx] = s;
-                                  {
-                                    mutex_lock l(mu_);
-                                    ++num_done_;
-                                    if (num_done_ == device_count) {
-                                      done_.notify_all();
-                                    }
-                                  }
-                                });
+    cp_res->CompleteParamsAsync(
+        device->attributes(), cp, &cm_,
+        [this, device_name, group_size](const Status& s) {
+          status_[device_name] = s;
+          {
+            mutex_lock l(mu_);
+            ++num_done_;
+            if (num_done_ == group_size) {
+              done_.notify_all();
+            }
+          }
+        });
   }
 
   void ValidateCollectiveParams(int num_workers, int num_devices) {
@@ -259,39 +245,59 @@ class DeviceResDistTest : public ::testing::Test {
     // Verify that all cp_ values get the same set of task and device
     // names, with unique default_rank in the expected order.
     const int dev_count = num_workers * num_devices;
+    string dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
     for (int wi = 0; wi < num_workers; ++wi) {
       string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
       for (int di = 0; di < num_devices; ++di) {
         string device_name = strings::StrCat(task_name, "/device:CPU:", di);
         int idx = wi * num_devices + di;
-        TF_ASSERT_OK(status_[idx]);
-        EXPECT_EQ(cp_[idx].default_rank, idx);
-        EXPECT_EQ(cp_[idx].instance.device_names.size(), dev_count);
-        EXPECT_EQ(cp_[idx].instance.device_names[idx], device_name);
-        EXPECT_EQ(cp_[idx].instance.task_names[idx], task_name);
+        TF_ASSERT_OK(status_[device_name]);
+        EXPECT_EQ(cp_[device_name].default_rank, idx);
+        EXPECT_EQ(cp_[device_name].instance.device_names.size(), dev_count);
+        EXPECT_EQ(cp_[device_name].instance.device_names[idx], device_name);
+        EXPECT_EQ(cp_[device_name].instance.task_names[idx], task_name);
         if (idx > 0) {
-          EXPECT_EQ(cp_[0].group.runtime_details.communicator_key,
-                    cp_[idx].group.runtime_details.communicator_key);
+          EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key,
+                    cp_[device_name].group.runtime_details.communicator_key);
           for (int i = 0; i < dev_count; ++i) {
-            EXPECT_EQ(cp_[0].instance.device_names[i],
-                      cp_[idx].instance.device_names[i]);
-            EXPECT_EQ(cp_[0].instance.task_names[i],
-                      cp_[idx].instance.task_names[i]);
+            EXPECT_EQ(cp_[dev0].instance.device_names[i],
+                      cp_[device_name].instance.device_names[i]);
+            EXPECT_EQ(cp_[dev0].instance.task_names[i],
+                      cp_[device_name].instance.task_names[i]);
           }
         }
       }
     }
   }
 
+  void RestartWorker(int worker_idx, int num_workers, int num_devices,
+                     const string& device_type, bool nccl) {
+    string worker_name =
+        strings::StrCat("/job:worker/replica:0/task:", worker_idx);
+    DefineWorker(worker_name, device_type, num_devices, nccl);
+    for (int i = 0; i < num_devices; ++i) {
+      string device_name =
+          strings::StrCat(worker_name, "/device:", device_type, ":", i);
+      cp_[device_name] =
+          CreateCollectiveParams(num_workers, num_devices, device_type);
+      status_.erase(device_name);
+    }
+  }
+
   FakeCache wc_;
   CancellationManager cm_;
-  std::vector<DeviceMgr*> device_mgrs_;
-  std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
-  std::unordered_map<string, CollectiveParamResolverDistributed*> cp_resolvers_;
-  std::unordered_map<string, std::vector<string>> dev_by_task_;
-  std::vector<FakeWorker*> workers_;
-  std::vector<CollectiveParams> cp_;
-  std::vector<Status> status_;
+  // Below are keyed by task names.
+  absl::flat_hash_map<string, std::unique_ptr<DeviceMgr>> device_mgrs_;
+  absl::flat_hash_map<string, std::unique_ptr<DeviceResolverDistributed>>
+      dev_resolvers_;
+  absl::flat_hash_map<string,
+                      std::unique_ptr<CollectiveParamResolverDistributed>>
+      cp_resolvers_;
+  absl::flat_hash_map<string, std::vector<string>> dev_by_task_;
+  absl::flat_hash_map<string, std::unique_ptr<FakeWorker>> workers_;
+  // Below are keyed by device names.
+  absl::flat_hash_map<string, CollectiveParams> cp_;
+  absl::flat_hash_map<string, Status> status_;
   mutex mu_;
   int num_done_ TF_GUARDED_BY(mu_);
   condition_variable done_;
@@ -300,8 +306,8 @@
 TEST_F(DeviceResDistTest, Workers1Devices1) {
   const int num_workers = 1;
   const int num_devices = 1;
-  DefineWorkers(num_workers, num_devices, "CPU", false);
-  DefineCollectiveParams(num_workers, num_devices);
+  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
+  DefineCollectiveParams(num_workers, num_devices, "CPU");
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
@@ -309,12 +315,25 @@ TEST_F(DeviceResDistTest, Workers1Devices1) {
 TEST_F(DeviceResDistTest, Workers2Devices2) {
   const int num_workers = 2;
   const int num_devices = 2;
-  DefineWorkers(num_workers, num_devices, "CPU", false);
-  DefineCollectiveParams(num_workers, num_devices);
+  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
+  DefineCollectiveParams(num_workers, num_devices, "CPU");
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
 
+TEST_F(DeviceResDistTest, DifferentIncarnation) {
+  const int num_workers = 2;
+  const int num_devices = 1;
+  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
+  DefineCollectiveParams(num_workers, num_devices, "CPU");
+  IssueRequests(num_workers, num_devices);
+  RestartWorker(1, num_workers, num_devices, "CPU", /*nccl*/ false);
+  const string task_name = "/job:worker/replica:0/task:1";
+  const string device_name = absl::StrCat(task_name, "/device:CPU:0");
+  IssueRequest(task_name, device_name, num_workers * num_devices);
+  EXPECT_TRUE(errors::IsFailedPrecondition(status_[device_name]));
+}
+
 #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
 namespace {
 // A mock NcclReducer for testing group runtime details initialization with CPU
@@ -347,7 +366,7 @@ TEST_F(DeviceResDistTest, Workers4Devices3) {
   const int num_workers = 4;
   const int num_devices = 3;
   DefineWorkers(num_workers, num_devices, "CPU", true);
-  DefineCollectiveParams(num_workers, num_devices);
+  DefineCollectiveParams(num_workers, num_devices, "CPU");
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index 72e0b3d9224..0f7a7ffba96 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -180,7 +180,8 @@ class ParamResolverInterface {
   // Called by each collective op at first execution in order to fill out
   // the CollectiveParams structure with data gathered from the full
   // (maybe distributed) collection of peer nodes.
- virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp, + virtual void CompleteParamsAsync(const DeviceAttributes& device, + CollectiveParams* cp, CancellationManager* cancel_mgr, const StatusCallback& done) = 0; @@ -301,7 +302,8 @@ class CollectiveExecutor : public core::RefCounted { "a CollectiveExecutor has not been provided.")); } - virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp, + virtual void CompleteParamsAsync(const DeviceAttributes& device, + CollectiveParams* cp, CancellationManager* cancel_mgr, StatusCallback done) { done(errors::Internal( diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc index 0230852d082..522c0967316 100644 --- a/tensorflow/core/kernels/collective_ops.cc +++ b/tensorflow/core/kernels/collective_ops.cc @@ -73,7 +73,7 @@ class CollectiveOpKernel : public AsyncOpKernel { << " group " << col_params_.group.group_key << " instance " << col_params_.instance.instance_key; col_exec->CompleteParamsAsync( - c->device()->name(), &col_params_, c->cancellation_manager(), + c->device()->attributes(), &col_params_, c->cancellation_manager(), [this, c, done](const Status& s) { if (s.ok()) { col_params_.instance.impl_details.dependencies = dependencies_; @@ -538,7 +538,8 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->CompleteParamsAsync( - c->device()->name(), col_params.get(), c->cancellation_manager(), + c->device()->attributes(), col_params.get(), + c->cancellation_manager(), [c, done = std::move(done), col_params, col_exec](const Status& s) { if (s.ok()) { auto actual_done = [c, group_key = col_params->group.group_key, diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 739ba8e03e6..0b4b50236c4 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -545,8 +545,10 @@ message CompleteGroupRequest { int32 group_key = 1; int32 group_size = 2; string device_type = 3; - repeated string device_name = 4; int32 collective_type = 5; + DeviceAttributes device_attributes = 6; + + reserved 4; } // Gives the complete membership of the group identified by group_key. @@ -555,9 +557,10 @@ message CompleteGroupResponse { int32 group_size = 2; string device_type = 3; int32 num_tasks = 4; // number of distinct tasks hosting the devices - repeated string device_name = 5; - repeated string task_name = 6; // task name prefixes of device_names bytes communicator_key = 7; + repeated DeviceAttributes device_attributes = 8; + + reserved 5, 6; } // Supplies data about one collective op belonging to the instance identified
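
Postscript: the heart of this patch is the incarnation check that CompleteGroupLocal now performs when a device name re-joins a group it is already recorded in. The following is a minimal, self-contained sketch of that detection logic for illustration only; GroupView and AdmitDevice are hypothetical names, not TensorFlow APIs, and the real logic lives in CollectiveParamResolverLocal::CompleteGroupLocal in the diff above.

    #include <cstdint>
    #include <string>
    #include <unordered_map>

    // Hypothetical stand-in for GroupRec: device name -> incarnation recorded
    // when that device first joined the group.
    struct GroupView {
      std::unordered_map<std::string, uint64_t> incarnations;
    };

    // Admits a device into the group. Returns false when the device name is
    // already known but reports a different incarnation, which means the
    // process backing that device restarted after group resolution, so any
    // cached group state referring to it is stale.
    bool AdmitDevice(GroupView* group, const std::string& device_name,
                     uint64_t incarnation) {
      auto it = group->incarnations.find(device_name);
      if (it == group->incarnations.end()) {
        // First join: record the incarnation for later comparison.
        group->incarnations[device_name] = incarnation;
        return true;
      }
      // Re-join: each process draws a fresh random incarnation at startup
      // (cf. attr.set_incarnation(random::New64()) in the test above), so a
      // mismatch reliably signals a restart.
      return it->second == incarnation;
    }

For example, if a worker joins with incarnation 0x1234, restarts, and then reports 0x9876 for the same device name, AdmitDevice returns false; this mirrors the FailedPrecondition error that the resolver now raises instead of silently handing out cached group parameters.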