diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 73c1458eab4..313866c7d5d 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -379,7 +379,6 @@ cc_library( hdrs = ["collective_param_resolver_local.h"], copts = tf_copts(), deps = [ - ":device", ":device_mgr", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index a6629286698..6e5e5c82c42 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -298,8 +298,8 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, } void BaseCollectiveExecutor::CompleteParamsAsync( - const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr, - StatusCallback done) { + const DeviceAttributes& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, StatusCallback done) { cp->instance.gpu_ring_order = *gpu_ring_order_; const auto is_callback_called = std::make_shared>(false); auto done_with_timeout = done; diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h index c9cea393378..4081b887add 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.h +++ b/tensorflow/core/common_runtime/base_collective_executor.h @@ -113,7 +113,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor { void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params, const string& exec_key, StatusCallback done) override; - void CompleteParamsAsync(const string& device, CollectiveParams* cp, + void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, CancellationManager* cancel_mgr, StatusCallback done) override; diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index ba21abcbaa8..cae7dd56621 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include +#include #include #include "tensorflow/core/common_runtime/device_mgr.h" @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" @@ -74,12 +75,21 @@ const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) { return "undef"; } } + +string TaskNameFromDeviceName(const string& device_name) { + DeviceNameUtils::ParsedName parsed_device; + CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device)); + string task_name; + CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name)); + return task_name; +} } // namespace void CollectiveParamResolverLocal::CompleteGroupLocal( - const string& device, CollectiveParams* cp, const GroupRecCallback& done) { - VLOG(1) << "CompleteGroupLocal device=" << device << " cp: " << cp << ": " - << cp->ToString(); + const DeviceAttributes& device, CollectiveParams* cp, + const GroupRecCallback& done) { + VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp + << ": " << cp->ToString(); std::vector to_be_called; GroupRec* gr = nullptr; Status status; @@ -139,13 +149,13 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( // status. VLOG(2) << "gr device_type=" << gr->group.device_type << " cp device_type=" << cp->group.device_type - << " current device=" << device; + << " current device=" << device.name(); if (gr->status.ok()) { // Check for consistency with existing GroupRec. if (cp->group.device_type != gr->group.device_type) { gr->status = errors::Internal( - "Collective Op ", cp->name, " is assigned to device ", device, - " with type ", cp->group.device_type.type_string(), + "Collective Op ", cp->name, " is assigned to device ", + device.name(), " with type ", cp->group.device_type.type_string(), " and group_key ", cp->group.group_key, " but that group has type ", gr->group.device_type.type_string()); } else if (cp->group.group_size != gr->group.group_size) { @@ -157,38 +167,47 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( } if (gr->status.ok()) { // Insert device if not already present. - auto it = gr->device_set.find(device); - if (it == gr->device_set.end()) { - if (gr->device_set.size() == gr->group.group_size) { + auto it = gr->devices.find(device.name()); + if (it == gr->devices.end()) { + if (gr->devices.size() == gr->group.group_size) { // The group is already full. gr->status = errors::Internal( - "Collective Op ", cp->name, " is assigned to device ", device, - " and group_key ", cp->group.group_key, + "Collective Op ", cp->name, " is assigned to device ", + device.name(), " and group_key ", cp->group.group_key, " but that group doesn't contain that device."); } else { // This is a new device that has not yet joined the group. - gr->device_set.insert(device); - gr->device_list.push_back(device); - DeviceNameUtils::ParsedName parsed_device; - DeviceNameUtils::ParseFullName(device, &parsed_device); - string task_name = strings::StrCat("/job:", parsed_device.job, - "/replica:", parsed_device.replica, - "/task:", parsed_device.task); - gr->task_set.insert(task_name); - gr->task_list.push_back(task_name); - gr->group.num_tasks = static_cast(gr->task_set.size()); + gr->devices[device.name()] = device; + if (gr->devices.size() == gr->group.group_size) { + // The group is full after adding this device, calculate the number + // of tasks. + std::unordered_set tasks; + for (const auto& item : gr->devices) { + tasks.insert(TaskNameFromDeviceName(item.first)); + } + gr->group.num_tasks = static_cast(tasks.size()); + } if (VLOG_IS_ON(1)) { string dev_buf; - for (const auto& d : gr->device_set) { - strings::StrAppend(&dev_buf, ",", d); + for (const auto& d : gr->devices) { + strings::StrAppend(&dev_buf, ",", d.first); } VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key << " group_size=" << gr->group.group_size << " (current" << " devices)=(" << dev_buf << ") (number of" << " devices pending)=" - << (gr->group.group_size - gr->device_set.size()); + << (gr->group.group_size - gr->devices.size()); } } + } else { + // If the device already exists, check if the incarnation matches. + if (it->second.incarnation() != device.incarnation()) { + gr->status = errors::FailedPrecondition( + "Device ", device.name(), + " current incarnation doesn't match with one in the group. This " + "usually means this worker has restarted but the collective " + "leader hasn't, or this worker connects to a wrong cluster."); + } } } @@ -196,13 +215,13 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( cp->group.runtime_details = gr->group.runtime_details; // If the group is not yet complete, queue to wait for it. VLOG(2) << "group_size " << gr->group.group_size << " set size " - << gr->device_set.size() << " gr " << gr; + << gr->devices.size() << " gr " << gr; - if (gr->device_set.size() < gr->group.group_size) { + if (gr->devices.size() < gr->group.group_size) { gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr)); return; } - CHECK_EQ(gr->device_set.size(), gr->group.group_size); + CHECK_EQ(gr->devices.size(), gr->group.group_size); } // At this point, we either have a full group, or an error status. Ensure // that all callbacks are invoked with the appropriate status. @@ -481,10 +500,15 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams( { mutex_lock gl(gr->mu); ir->shared.group = gr->group; - ir->shared.instance.device_names.assign(gr->device_list.begin(), - gr->device_list.end()); - ir->shared.instance.task_names.assign(gr->task_list.begin(), - gr->task_list.end()); + ir->shared.instance.device_names.clear(); + ir->shared.instance.task_names.clear(); + ir->shared.instance.device_names.reserve(gr->devices.size()); + ir->shared.instance.task_names.reserve(gr->devices.size()); + for (const auto& item : gr->devices) { + ir->shared.instance.device_names.push_back(item.first); + ir->shared.instance.task_names.push_back( + TaskNameFromDeviceName(item.first)); + } VLOG(2) << "Initialized names for instance: " << ir->shared.instance.ToString(); } @@ -682,15 +706,15 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams( } void CollectiveParamResolverLocal::CompleteParamsAsync( - const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr, - const StatusCallback& done) { - VLOG(1) << "CompleteParams local " << device << " for " << cp << ": " + const DeviceAttributes& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, const StatusCallback& done) { + VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": " << cp->ToString(); CompleteGroupLocal( device, cp, [this, device, cp, done](const Status& s, const GroupRec* gr) { if (s.ok()) { - CompleteInstanceLocal(device, gr, cp, cp->is_source, done); + CompleteInstanceLocal(device.name(), gr, cp, cp->is_source, done); } else { done(s); } diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 40f0f00affc..51cfa893d00 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -19,9 +19,11 @@ limitations under the License. #include #include #include +#include #include #include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -45,7 +47,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { ~CollectiveParamResolverLocal() override {} - void CompleteParamsAsync(const string& device, CollectiveParams* cp, + void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, CancellationManager* cancel_mgr, const StatusCallback& done) override; @@ -70,10 +72,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { CollGroupParams group; mutable mutex mu; Status status TF_GUARDED_BY(mu); - std::set device_set TF_GUARDED_BY(mu); - std::vector device_list TF_GUARDED_BY(mu); - std::set task_set TF_GUARDED_BY(mu); - std::vector task_list TF_GUARDED_BY(mu); + std::unordered_map devices TF_GUARDED_BY(mu); std::vector waiting TF_GUARDED_BY(mu); }; @@ -85,7 +84,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { // callback. typedef std::function GroupRecCallback; - void CompleteGroupLocal(const string& device, CollectiveParams* cp, + void CompleteGroupLocal(const DeviceAttributes& device, CollectiveParams* cp, const GroupRecCallback& done) TF_LOCKS_EXCLUDED(group_mu_); diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc index f23f03dc406..b117632dbd2 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_resolver_local.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -86,6 +87,12 @@ class CollectiveParamResolverLocalTest : public ::testing::Test { } } + DeviceAttributes GetDeviceAttributes(const string& device_name) { + Device* device = nullptr; + TF_CHECK_OK(device_mgr_->LookupDevice(device_name, &device)); + return device->attributes(); + } + string task_name_; std::unique_ptr device_mgr_; std::unique_ptr drl_; @@ -187,12 +194,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { cp->instance.impl_details.subdiv_offsets.push_back(0); cp->is_source = false; Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() { - prl_->CompleteParamsAsync(cp->instance.device_names[0], cp, - nullptr /*CancellationManager*/, - [&statuses, ¬e, i](const Status& s) { - statuses[i] = s; - note[i].Notify(); - }); + prl_->CompleteParamsAsync( + GetDeviceAttributes(cp->instance.device_names[0]), cp, + nullptr /*CancellationManager*/, + [&statuses, ¬e, i](const Status& s) { + statuses[i] = s; + note[i].Notify(); + }); }); } for (int i = 0; i < NUM_DEVS; ++i) { @@ -240,12 +248,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) { CollectiveParams* cp = &cps[i]; InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp); Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() { - prl_->CompleteParamsAsync(cp->instance.device_names[0], cp, - nullptr /*CancellationManager*/, - [&statuses, ¬e, i](const Status& s) { - statuses[i] = s; - note[i].Notify(); - }); + prl_->CompleteParamsAsync( + GetDeviceAttributes(cp->instance.device_names[0]), cp, + nullptr /*CancellationManager*/, + [&statuses, ¬e, i](const Status& s) { + statuses[i] = s; + note[i].Notify(); + }); }); } for (int i = 0; i < NUM_DEVS; ++i) { @@ -278,12 +287,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) { CollectiveParams* cp = &cps[i]; InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp); Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() { - prl_->CompleteParamsAsync(cp->instance.device_names[0], cp, - nullptr /*CancellationManager*/, - [&statuses, ¬e, i](const Status& s) { - statuses[i] = s; - note[i].Notify(); - }); + prl_->CompleteParamsAsync( + GetDeviceAttributes(cp->instance.device_names[0]), cp, + nullptr /*CancellationManager*/, + [&statuses, ¬e, i](const Status& s) { + statuses[i] = s; + note[i].Notify(); + }); }); } for (int i = 0; i < NUM_DEVS; ++i) { @@ -326,8 +336,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr, - [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], + &cancel_mgr, [&done](const Status& s) { EXPECT_EQ(s.code(), error::ABORTED); EXPECT_EQ(s.error_message(), "__aborted__"); done.DecrementCount(); @@ -355,8 +365,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(group_key, instance_key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr, - [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], + &cancel_mgr, [&done](const Status& s) { EXPECT_EQ(s.code(), error::OK); done.DecrementCount(); }); @@ -373,12 +383,13 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(group_key, instance_key + 1, /*is_source*/ i == 0); - prl_->CompleteParamsAsync( - device, &cp[i], &cancel_mgr, [&done](const Status& s) { - EXPECT_EQ(s.code(), error::ABORTED); - EXPECT_EQ(s.error_message(), "__aborted__"); - done.DecrementCount(); - }); + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], + &cancel_mgr, [&done](const Status& s) { + EXPECT_EQ(s.code(), error::ABORTED); + EXPECT_EQ(s.error_message(), + "__aborted__"); + done.DecrementCount(); + }); start.DecrementCount(); }); } @@ -402,8 +413,8 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(group_key, instance_key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr, - [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], + &cancel_mgr, [&done](const Status& s) { EXPECT_EQ(s.code(), error::OK); done.DecrementCount(); }); @@ -418,7 +429,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { Notification done; auto cp = MakeCollectiveParams(group_key, instance_key, /*is_source*/ true); - prl_->CompleteParamsAsync(device, &cp, &cancel_mgr, + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, &cancel_mgr, [&done](const Status& s) { EXPECT_EQ(s.code(), error::ABORTED); EXPECT_EQ(s.error_message(), "__aborted__"); @@ -457,7 +468,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) { auto cp = MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(device, &cp, &cancel_mgr, + prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, + &cancel_mgr, [&status, &n](const Status& s) { status = s; n.Notify(); diff --git a/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/tensorflow/core/common_runtime/test_collective_executor_mgr.h index c2e6d2ae08c..ff4d966d76c 100644 --- a/tensorflow/core/common_runtime/test_collective_executor_mgr.h +++ b/tensorflow/core/common_runtime/test_collective_executor_mgr.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_ #include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { @@ -35,7 +36,7 @@ class TestCollectiveExecutor : public CollectiveExecutor { }; class TestParamResolver : public ParamResolverInterface { - void CompleteParamsAsync(const string& device, CollectiveParams* cp, + void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, CancellationManager* cancel_mgr, const StatusCallback& done) override { done(errors::Internal("Unimplemented")); diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 94570c1b577..6e87d28781f 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -571,7 +571,9 @@ cc_library( ":device_resolver_distributed", ":worker_cache", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", "@com_google_absl//absl/strings", ], ) @@ -606,6 +608,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/kernels:collective_ops", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index 650c52cd8da..44d90818f9b 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -18,14 +18,18 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/cancellable_call.h" #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { namespace { class CompleteGroupCall : public CancellableCall { public: - CompleteGroupCall(const CollGroupParams& group, const string& device_name, + CompleteGroupCall(const CollGroupParams& group, + const DeviceAttributes& device, const CollectiveType& collective_type, CancellationManager* cancel_mgr, const string& remote_worker, WorkerCacheInterface* wc) @@ -33,7 +37,7 @@ class CompleteGroupCall : public CancellableCall { req_.set_group_key(group.group_key); req_.set_group_size(group.group_size); req_.set_device_type(group.device_type.type_string()); - req_.add_device_name(device_name); + *req_.mutable_device_attributes() = device; req_.set_collective_type(collective_type); } ~CompleteGroupCall() override {} @@ -98,16 +102,16 @@ CollectiveParamResolverDistributed::CollectiveParamResolverDistributed( } void CollectiveParamResolverDistributed::CompleteParamsAsync( - const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr, - const StatusCallback& done) { - VLOG(1) << "CompleteParams distributed " << device << " for " << cp << ": " - << cp->ToString(); + const DeviceAttributes& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, const StatusCallback& done) { + VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp + << ": " << cp->ToString(); CompleteGroupDistributed(device, cp, cancel_mgr, [this, device, cp, cancel_mgr, done]( const Status& s, const GroupRec* gr) { if (s.ok()) { - CompleteInstanceDistributed(device, gr, cp, - cancel_mgr, done); + CompleteInstanceDistributed( + device.name(), gr, cp, cancel_mgr, done); } else { done(s); } @@ -117,28 +121,28 @@ void CollectiveParamResolverDistributed::CompleteParamsAsync( void CollectiveParamResolverDistributed::CompleteGroupAsync( const CompleteGroupRequest* request, CompleteGroupResponse* response, CancellationManager* cancel_mgr, const StatusCallback& done) { + if (!request->has_device_attributes()) { + done(errors::Internal( + "CompleteGroupRequest device_attributes is not set. Make sure you're " + "running the same version of Tensorflow on all workers.")); + return; + } CollectiveParams cp; cp.group.group_key = request->group_key(); cp.group.group_size = request->group_size(); cp.group.device_type = DeviceType(request->device_type()); - for (const string& dn : request->device_name()) { - cp.instance.device_names.push_back(dn); - } cp.instance.type = CollectiveType(request->collective_type()); CompleteGroupDistributed( - cp.instance.device_names[0], &cp, cancel_mgr, + request->device_attributes(), &cp, cancel_mgr, [response, done](const Status& s, const GroupRec* gr) { if (s.ok()) { mutex_lock l(gr->mu); response->set_group_key(gr->group.group_key); response->set_group_size(gr->group.group_size); response->set_device_type(gr->group.device_type.type_string()); - response->set_num_tasks(gr->task_set.size()); - for (const string& dn : gr->device_list) { - response->add_device_name(dn); - } - for (const string& tn : gr->task_list) { - response->add_task_name(tn); + response->set_num_tasks(gr->group.num_tasks); + for (const auto& item : gr->devices) { + *response->add_device_attributes() = item.second; } response->set_communicator_key( gr->group.runtime_details.communicator_key); @@ -152,6 +156,22 @@ void CollectiveParamResolverDistributed::CompleteGroupAsync( void CollectiveParamResolverDistributed::CompleteInstanceAsync( const CompleteInstanceRequest* request, CompleteInstanceResponse* response, CancellationManager* cancel_mgr, const StatusCallback& done) { + GroupRec* gr = GetCachedGroup(request->group_key()); + if (gr == nullptr) { + done(errors::FailedPrecondition( + "group ", request->group_key(), + " not found. This normally means the server has restarted")); + return; + } + { + mutex_lock l(gr->mu); + if (!gr->status.ok() || gr->devices.size() != gr->group.group_size) { + done(errors::FailedPrecondition( + "group ", request->group_key(), + " failed to resolve. This normally means the server has restarted")); + return; + } + } CollectiveParams* cp = new CollectiveParams; cp->name = request->name(); cp->group.group_key = request->group_key(); @@ -164,56 +184,44 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync( for (int32 offset : request->subdiv_offset()) { cp->instance.impl_details.subdiv_offsets.push_back(offset); } - string* device = new string(request->device()); - VLOG(1) << "New cp " << cp << " for device " << *device << " : " - << cp->ToString(); - StatusCallback done_and_cleanup = [cp, device, done](const Status& s) { + StatusCallback done_and_cleanup = [cp, done](const Status& s) { done(s); delete cp; - delete device; }; - // Start by completing the group. - CompleteGroupDistributed( - *device, cp, cancel_mgr, - [this, cp, device, response, cancel_mgr, done_and_cleanup]( - const Status& cg_status, const GroupRec* gr) { - if (cg_status.ok()) { - // Then complete the instance. - CompleteInstanceDistributed( - *device, gr, cp, cancel_mgr, - [this, gr, cp, response, - done_and_cleanup](const Status& ci_status) { - if (ci_status.ok()) { - // Now source_rank should be known, so - // retrieve it. - FindInstanceRec( - gr, cp, - [cp, response, done_and_cleanup](const Status& fi_status, - InstanceRec* ir) { - if (fi_status.ok()) { - mutex_lock l(ir->out_mu); - ir->WaitForOutMu(l); - response->set_instance_key(cp->instance.instance_key); - response->set_source_rank(ir->source_rank); - done_and_cleanup(fi_status); - } else { - done_and_cleanup(fi_status); - } - }); + CompleteInstanceDistributed( + request->device(), gr, cp, cancel_mgr, + [this, gr, cp, response, done_and_cleanup](const Status& ci_status) { + if (ci_status.ok()) { + // Now source_rank should be known, so + // retrieve it. + FindInstanceRec( + gr, cp, + [cp, response, done_and_cleanup](const Status& fi_status, + InstanceRec* ir) { + if (fi_status.ok()) { + mutex_lock l(ir->out_mu); + ir->WaitForOutMu(l); + response->set_instance_key(cp->instance.instance_key); + response->set_source_rank(ir->source_rank); + done_and_cleanup(fi_status); } else { - done_and_cleanup(ci_status); + done_and_cleanup(fi_status); } }); } else { - done_and_cleanup(cg_status); + done_and_cleanup(ci_status); } }); } -bool CollectiveParamResolverDistributed::GroupIsCached(int32 group_key) { +CollectiveParamResolverDistributed::GroupRec* +CollectiveParamResolverDistributed::GetCachedGroup(int32 group_key) { mutex_lock l(group_mu_); - const auto& it = group_table_.find(group_key); - return it != group_table_.end(); + auto it = group_table_.find(group_key); + if (it == group_table_.end()) { + return nullptr; + } + return it->second.get(); } Status CollectiveParamResolverDistributed::UpdateGroupCache( @@ -226,26 +234,19 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache( gr->group.group_key = resp.group_key(); gr->group.group_size = resp.group_size(); gr->group.num_tasks = resp.num_tasks(); - if (resp.device_name_size() != gr->group.group_size) { + if (resp.device_attributes().empty()) { + return errors::Internal( + "CompleteGroupResponse device_attributes is empty. Make sure you're " + "running the same version of Tensorflow on all workers."); + } + if (resp.device_attributes_size() != gr->group.group_size) { return errors::Internal( "CompleteGroupResponse group_size doesn't match device_name list"); } - for (const string& dn : resp.device_name()) { - gr->device_set.insert(dn); - gr->device_list.push_back(dn); + for (const DeviceAttributes& device : resp.device_attributes()) { + gr->devices[device.name()] = device; } - if (resp.task_name_size() != gr->group.group_size) { - return errors::Internal( - "CompleteGroupResponse group_size doesn't match task_name list"); - } - for (const string& tn : resp.task_name()) { - gr->task_list.push_back(tn); - gr->task_set.insert(tn); - } - CHECK_EQ(gr->task_set.size(), gr->group.num_tasks); gr->group.runtime_details.communicator_key = resp.communicator_key(); - VLOG(2) << "Group communicator_key=" - << absl::CEscape(gr->group.runtime_details.communicator_key); } { // Group membership should never change. Once a record is in group_table_ @@ -273,14 +274,15 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache( } void CollectiveParamResolverDistributed::CompleteGroupDistributed( - const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr, - const GroupRecCallback& done) { + const DeviceAttributes& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, const GroupRecCallback& done) { VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key - << " dev: " << device << " is_leader=" << (group_leader_.empty()); + << " dev: " << device.name() + << " is_leader=" << (group_leader_.empty()); if (group_leader_.empty()) { // This is the group leader, so resolution is local. return CompleteGroupLocal(device, cp, done); - } else if (!GroupIsCached(cp->group.group_key)) { + } else if (GetCachedGroup(cp->group.group_key) == nullptr) { // Need to update Group cache from the leader. CompleteGroupCall* call = new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr, diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h index 684887430c3..fc692a12fc6 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_ #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" +#include "tensorflow/core/framework/device_attributes.pb.h" namespace tensorflow { class ConfigProto; @@ -31,7 +32,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { WorkerCacheInterface* worker_cache, const string& task_name); - void CompleteParamsAsync(const string& device, CollectiveParams* cp, + void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, CancellationManager* cancel_mgr, const StatusCallback& done) override; @@ -46,9 +47,9 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { const StatusCallback& done) override; protected: - // Returns true iff there's an entry for this group_key in the - // local group_table_. - bool GroupIsCached(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_); + // Returns the cached group iff there's an entry for this group_key in the + // local group_table_; returns nullptr otherwise. + GroupRec* GetCachedGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_); // Updates group_table_ with contents of resp. Status UpdateGroupCache(const CompleteGroupResponse& resp) @@ -59,7 +60,8 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { // // Semantics are like those of CompleteGroupLocal but will make a // remote call to the group leader if necessary. - void CompleteGroupDistributed(const string& device, CollectiveParams* cp, + void CompleteGroupDistributed(const DeviceAttributes& device, + CollectiveParams* cp, CancellationManager* cancel_mgr, const GroupRecCallback& done); diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc index 130a48e80d2..a963c02d413 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" #include "tensorflow/core/distributed_runtime/test_utils.h" @@ -23,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/device_name_utils.h" @@ -41,6 +43,7 @@ static std::unique_ptr NewDevice(const string& type, attr.set_name(name); attr.set_device_type(type); attr.mutable_locality()->set_numa_node(3); // a non-default value + attr.set_incarnation(random::New64()); return absl::make_unique(attr); } @@ -125,127 +128,110 @@ class FakeCache : public TestWorkerCache { class DeviceResDistTest : public ::testing::Test { protected: - DeviceResDistTest() {} - - ~DeviceResDistTest() override { - for (DeviceMgr* dm : device_mgrs_) { - delete dm; - } - for (auto it : dev_resolvers_) { - delete it.second; - } - for (auto it : cp_resolvers_) { - delete it.second; - } - for (FakeWorker* w : workers_) { - delete w; - } - } - void DefineWorkers(int num_workers, int num_devices, const string& device_type, bool nccl) { - ConfigProto config; for (int w = 0; w < num_workers; ++w) { string name = strings::StrCat("/job:worker/replica:0/task:", w); - if (w == 0) { - config.mutable_experimental()->set_collective_group_leader(name); - if (nccl) { - config.mutable_experimental()->set_collective_nccl(true); - } - } - DefineWorker(config, name, device_type, num_devices); + DefineWorker(name, device_type, num_devices, nccl); } } - void DefineWorker(const ConfigProto& config, const string& worker_name, - const string& device_type, int num_devices) { + void DefineWorker(const string& worker_name, const string& device_type, + int num_devices, bool nccl) { + ConfigProto config; + config.mutable_experimental()->set_collective_group_leader( + "/job:worker/replica:0/task:0"); + config.mutable_experimental()->set_collective_nccl(nccl); + std::vector> devices; for (int i = 0; i < num_devices; ++i) { devices.push_back(NewDevice( device_type, strings::StrCat(worker_name, "/device:", device_type, ":", i))); } - DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices)); - device_mgrs_.push_back(dev_mgr); + device_mgrs_[worker_name] = + absl::make_unique(std::move(devices)); std::vector* dv = &dev_by_task_[worker_name]; - for (auto* d : dev_mgr->ListDevices()) { + dv->clear(); + for (auto* d : device_mgrs_[worker_name]->ListDevices()) { dv->push_back(d->name()); } - DeviceResolverDistributed* dev_res = - new DeviceResolverDistributed(dev_mgr, &wc_, worker_name); - dev_resolvers_[worker_name] = dev_res; - CollectiveParamResolverDistributed* cp_res = - new CollectiveParamResolverDistributed(config, dev_mgr, dev_res, &wc_, - worker_name); - cp_resolvers_[worker_name] = cp_res; - FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, cp_res); - workers_.push_back(fw); - wc_.AddWorker(worker_name, fw); + dev_resolvers_[worker_name] = absl::make_unique( + device_mgrs_[worker_name].get(), &wc_, worker_name); + cp_resolvers_[worker_name] = + absl::make_unique( + config, device_mgrs_[worker_name].get(), + dev_resolvers_[worker_name].get(), &wc_, worker_name); + workers_[worker_name] = absl::make_unique( + worker_name, device_mgrs_[worker_name].get(), + cp_resolvers_[worker_name].get()); + wc_.AddWorker(worker_name, workers_[worker_name].get()); } - void DefineCollectiveParams(int num_workers, int num_devices) { - const int kGroupKey = 5; - const int kInstanceKey = 3; + void DefineCollectiveParams(int num_workers, int num_devices, + const string& device_type) { for (int wi = 0; wi < num_workers; ++wi) { string task_name = strings::StrCat("/job:worker/replica:0/task:", wi); for (int di = 0; di < num_devices; ++di) { - string device_name = strings::StrCat(task_name, "/device:CPU:", di); - cp_.push_back(CollectiveParams()); - CollectiveParams& cp = cp_.back(); - cp.group.group_key = kGroupKey; - cp.group.group_size = num_workers * num_devices; - cp.group.device_type = DEVICE_CPU; - cp.group.num_tasks = num_workers; - cp.instance.instance_key = kInstanceKey; - cp.instance.type = REDUCTION_COLLECTIVE; - cp.instance.data_type = DT_FLOAT; - cp.instance.shape = TensorShape({64}); - cp.instance.impl_details.subdiv_offsets.push_back(0); + string device_name = + strings::StrCat(task_name, "/device:", device_type, ":", di); + cp_[device_name] = + CreateCollectiveParams(num_workers, num_devices, device_type); } } } + CollectiveParams CreateCollectiveParams(int num_workers, int num_devices, + const string& device_type) { + const int kGroupKey = 5; + const int kInstanceKey = 3; + CollectiveParams cp; + cp.group.group_key = kGroupKey; + cp.group.group_size = num_workers * num_devices; + cp.group.device_type = DeviceType(device_type); + cp.group.num_tasks = num_workers; + cp.instance.instance_key = kInstanceKey; + cp.instance.type = REDUCTION_COLLECTIVE; + cp.instance.data_type = DT_FLOAT; + cp.instance.shape = TensorShape({64}); + cp.instance.impl_details.subdiv_offsets.push_back(0); + return cp; + } + void IssueRequests(int num_workers, int num_devices) { - const int device_count = num_workers * num_devices; { mutex_lock l(mu_); num_done_ = 0; } - cp_.resize(device_count); - status_.resize(device_count); - int idx = 0; + int group_size = num_workers * num_devices; for (int wi = 0; wi < num_workers; ++wi) { + string task_name = strings::StrCat("/job:worker/replica:0/task:", wi); for (int di = 0; di < num_devices; ++di) { - IssueRequest(num_workers, num_devices, idx); - ++idx; + string device_name = strings::StrCat(task_name, "/device:CPU:", di); + IssueRequest(task_name, device_name, group_size); } } } - void IssueRequest(int num_workers, int num_devices, int idx) { - int device_count = num_workers * num_devices; - int wi = idx / num_devices; - int di = idx % num_devices; - string task_name = strings::StrCat("/job:worker/replica:0/task:", wi); - string device_name = strings::StrCat(task_name, "/device:CPU:", di); - while (idx >= cp_.size()) { - status_.resize(idx + 1); - cp_.resize(idx + 1); - } - CollectiveParams* cp = &cp_[idx]; - CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name]; + void IssueRequest(const string& task_name, const string& device_name, + int group_size) { + Device* device = nullptr; + TF_CHECK_OK(device_mgrs_[task_name]->LookupDevice(device_name, &device)); + CollectiveParams* cp = &cp_[device_name]; + CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name].get(); CHECK(cp_res); - cp_res->CompleteParamsAsync(device_name, cp, &cm_, - [this, idx, device_count](const Status& s) { - status_[idx] = s; - { - mutex_lock l(mu_); - ++num_done_; - if (num_done_ == device_count) { - done_.notify_all(); - } - } - }); + cp_res->CompleteParamsAsync( + device->attributes(), cp, &cm_, + [this, device_name, group_size](const Status& s) { + status_[device_name] = s; + { + mutex_lock l(mu_); + ++num_done_; + if (num_done_ == group_size) { + done_.notify_all(); + } + } + }); } void ValidateCollectiveParams(int num_workers, int num_devices) { @@ -259,39 +245,59 @@ class DeviceResDistTest : public ::testing::Test { // Verify that all cp_ values get the same set of task and device // names, with unique default_rank in the expected order. const int dev_count = num_workers * num_devices; + string dev0 = "/job:worker/replica:0/task:0/device:CPU:0"; for (int wi = 0; wi < num_workers; ++wi) { string task_name = strings::StrCat("/job:worker/replica:0/task:", wi); for (int di = 0; di < num_devices; ++di) { string device_name = strings::StrCat(task_name, "/device:CPU:", di); int idx = wi * num_devices + di; - TF_ASSERT_OK(status_[idx]); - EXPECT_EQ(cp_[idx].default_rank, idx); - EXPECT_EQ(cp_[idx].instance.device_names.size(), dev_count); - EXPECT_EQ(cp_[idx].instance.device_names[idx], device_name); - EXPECT_EQ(cp_[idx].instance.task_names[idx], task_name); + TF_ASSERT_OK(status_[device_name]); + EXPECT_EQ(cp_[device_name].default_rank, idx); + EXPECT_EQ(cp_[device_name].instance.device_names.size(), dev_count); + EXPECT_EQ(cp_[device_name].instance.device_names[idx], device_name); + EXPECT_EQ(cp_[device_name].instance.task_names[idx], task_name); if (idx > 0) { - EXPECT_EQ(cp_[0].group.runtime_details.communicator_key, - cp_[idx].group.runtime_details.communicator_key); + EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key, + cp_[device_name].group.runtime_details.communicator_key); for (int i = 0; i < dev_count; ++i) { - EXPECT_EQ(cp_[0].instance.device_names[i], - cp_[idx].instance.device_names[i]); - EXPECT_EQ(cp_[0].instance.task_names[i], - cp_[idx].instance.task_names[i]); + EXPECT_EQ(cp_[dev0].instance.device_names[i], + cp_[device_name].instance.device_names[i]); + EXPECT_EQ(cp_[dev0].instance.task_names[i], + cp_[device_name].instance.task_names[i]); } } } } } + void RestartWorker(int worker_idx, int num_workers, int num_devices, + const string& device_type, bool nccl) { + string worker_name = + strings::StrCat("/job:worker/replica:0/task:", worker_idx); + DefineWorker(worker_name, device_type, num_devices, nccl); + for (int i = 0; i < num_devices; ++i) { + string device_name = + strings::StrCat(worker_name, "/device:", device_type, ":", i); + cp_[device_name] = + CreateCollectiveParams(num_workers, num_devices, device_type); + status_.erase(device_name); + } + } + FakeCache wc_; CancellationManager cm_; - std::vector device_mgrs_; - std::unordered_map dev_resolvers_; - std::unordered_map cp_resolvers_; - std::unordered_map> dev_by_task_; - std::vector workers_; - std::vector cp_; - std::vector status_; + // Below are keyed by task names. + absl::flat_hash_map> device_mgrs_; + absl::flat_hash_map> + dev_resolvers_; + absl::flat_hash_map> + cp_resolvers_; + absl::flat_hash_map> dev_by_task_; + absl::flat_hash_map> workers_; + // Below are keyed by device names; + absl::flat_hash_map cp_; + absl::flat_hash_map status_; mutex mu_; int num_done_ TF_GUARDED_BY(mu_); condition_variable done_; @@ -300,8 +306,8 @@ class DeviceResDistTest : public ::testing::Test { TEST_F(DeviceResDistTest, Workers1Devices1) { const int num_workers = 1; const int num_devices = 1; - DefineWorkers(num_workers, num_devices, "CPU", false); - DefineCollectiveParams(num_workers, num_devices); + DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false); + DefineCollectiveParams(num_workers, num_devices, "CPU"); IssueRequests(num_workers, num_devices); ValidateCollectiveParams(num_workers, num_devices); } @@ -309,12 +315,25 @@ TEST_F(DeviceResDistTest, Workers1Devices1) { TEST_F(DeviceResDistTest, Workers2Devices2) { const int num_workers = 2; const int num_devices = 2; - DefineWorkers(num_workers, num_devices, "CPU", false); - DefineCollectiveParams(num_workers, num_devices); + DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false); + DefineCollectiveParams(num_workers, num_devices, "CPU"); IssueRequests(num_workers, num_devices); ValidateCollectiveParams(num_workers, num_devices); } +TEST_F(DeviceResDistTest, DifferentIncarnation) { + const int num_workers = 2; + const int num_devices = 1; + DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false); + DefineCollectiveParams(num_workers, num_devices, "CPU"); + IssueRequests(num_workers, num_devices); + RestartWorker(1, num_workers, num_devices, "CPU", /*nccl*/ false); + const string task_name = "/job:worker/replica:0/task:1"; + const string device_name = absl::StrCat(task_name, "/device:CPU:0"); + IssueRequest(task_name, device_name, num_workers * num_devices); + EXPECT_TRUE(errors::IsFailedPrecondition(status_[device_name])); +} + #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM namespace { // A mock NcclReducer for testing group runtime details initialization with CPU @@ -347,7 +366,7 @@ TEST_F(DeviceResDistTest, Workers4Devices3) { const int num_workers = 4; const int num_devices = 3; DefineWorkers(num_workers, num_devices, "CPU", true); - DefineCollectiveParams(num_workers, num_devices); + DefineCollectiveParams(num_workers, num_devices, "CPU"); IssueRequests(num_workers, num_devices); ValidateCollectiveParams(num_workers, num_devices); } diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index 72e0b3d9224..0f7a7ffba96 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -180,7 +180,8 @@ class ParamResolverInterface { // Called by each collective op at first execution in order to fill out // the CollectiveParams structure with data gathered from the full // (maybe distributed) collection of peer nodes. - virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp, + virtual void CompleteParamsAsync(const DeviceAttributes& device, + CollectiveParams* cp, CancellationManager* cancel_mgr, const StatusCallback& done) = 0; @@ -301,7 +302,8 @@ class CollectiveExecutor : public core::RefCounted { "a CollectiveExecutor has not been provided.")); } - virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp, + virtual void CompleteParamsAsync(const DeviceAttributes& device, + CollectiveParams* cp, CancellationManager* cancel_mgr, StatusCallback done) { done(errors::Internal( diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc index 0230852d082..522c0967316 100644 --- a/tensorflow/core/kernels/collective_ops.cc +++ b/tensorflow/core/kernels/collective_ops.cc @@ -73,7 +73,7 @@ class CollectiveOpKernel : public AsyncOpKernel { << " group " << col_params_.group.group_key << " instance " << col_params_.instance.instance_key; col_exec->CompleteParamsAsync( - c->device()->name(), &col_params_, c->cancellation_manager(), + c->device()->attributes(), &col_params_, c->cancellation_manager(), [this, c, done](const Status& s) { if (s.ok()) { col_params_.instance.impl_details.dependencies = dependencies_; @@ -538,7 +538,8 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->CompleteParamsAsync( - c->device()->name(), col_params.get(), c->cancellation_manager(), + c->device()->attributes(), col_params.get(), + c->cancellation_manager(), [c, done = std::move(done), col_params, col_exec](const Status& s) { if (s.ok()) { auto actual_done = [c, group_key = col_params->group.group_key, diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 739ba8e03e6..0b4b50236c4 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -545,8 +545,10 @@ message CompleteGroupRequest { int32 group_key = 1; int32 group_size = 2; string device_type = 3; - repeated string device_name = 4; int32 collective_type = 5; + DeviceAttributes device_attributes = 6; + + reserved 4; } // Gives the complete membership of the group identified by group_key. @@ -555,9 +557,10 @@ message CompleteGroupResponse { int32 group_size = 2; string device_type = 3; int32 num_tasks = 4; // number of distinct tasks hosting the devices - repeated string device_name = 5; - repeated string task_name = 6; // task name prefixes of device_names bytes communicator_key = 7; + repeated DeviceAttributes device_attributes = 8; + + reserved 5, 6; } // Supplies data about one collective op belonging to the instance identified