Exchange device attributes at group resolution again

Previously CollectiveParamResolver queries device attributes when initializing instance params. That has issues when the collective leader fails and restarts quickly between group resolution and instance resolution. In such case, all other workers get the incarnation of the restarted leader, thus they're unable to detect that the leader has failed; the leader will deadlock on the group resolution.

This change doesn't fully fixed the issue because it only exchanges device attributes at group resolution, but doesn't populate the device attributes to DeviceResolver. That will be done in a following change.

This change also changes the behavior when a non-leader fails and restarts. Previously it gets the cached group resolution from the leader, now it will get an error because its incarnation doesn't match with the one in the cached group parameters. This should have no actual effect since that worker will always restart again after the leader has restarted.

This change changes both the client and server without being backward compatible. It assumes that client and server are running the same version of Tensorflow. This should be true since the only way to use CollectiveParamResolverDistributed is through MultiWorkerMirroredStrategy (MWMS). For MWMS, all workers should run the same version of the program.

PiperOrigin-RevId: 329774332
Change-Id: I6f3ba535cefd3a8ec321e4138436b6a085d64463
This commit is contained in:
Ran Chen 2020-09-02 13:16:06 -07:00 committed by TensorFlower Gardener
parent edb454c157
commit 9da4a05369
14 changed files with 340 additions and 273 deletions

View File

@ -379,7 +379,6 @@ cc_library(
hdrs = ["collective_param_resolver_local.h"],
copts = tf_copts(),
deps = [
":device",
":device_mgr",
"//tensorflow/core:framework",
"//tensorflow/core:lib",

View File

@ -298,8 +298,8 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
}
void BaseCollectiveExecutor::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
StatusCallback done) {
const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr, StatusCallback done) {
cp->instance.gpu_ring_order = *gpu_ring_order_;
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
auto done_with_timeout = done;

View File

@ -113,7 +113,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
const string& exec_key, StatusCallback done) override;
void CompleteParamsAsync(const string& device, CollectiveParams* cp,
void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr,
StatusCallback done) override;

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <stddef.h>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "tensorflow/core/common_runtime/device_mgr.h"
@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
@ -74,12 +75,21 @@ const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
return "undef";
}
}
string TaskNameFromDeviceName(const string& device_name) {
DeviceNameUtils::ParsedName parsed_device;
CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
string task_name;
CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
return task_name;
}
} // namespace
void CollectiveParamResolverLocal::CompleteGroupLocal(
const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
VLOG(1) << "CompleteGroupLocal device=" << device << " cp: " << cp << ": "
<< cp->ToString();
const DeviceAttributes& device, CollectiveParams* cp,
const GroupRecCallback& done) {
VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp
<< ": " << cp->ToString();
std::vector<StatusCallback> to_be_called;
GroupRec* gr = nullptr;
Status status;
@ -139,13 +149,13 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
// status.
VLOG(2) << "gr device_type=" << gr->group.device_type
<< " cp device_type=" << cp->group.device_type
<< " current device=" << device;
<< " current device=" << device.name();
if (gr->status.ok()) {
// Check for consistency with existing GroupRec.
if (cp->group.device_type != gr->group.device_type) {
gr->status = errors::Internal(
"Collective Op ", cp->name, " is assigned to device ", device,
" with type ", cp->group.device_type.type_string(),
"Collective Op ", cp->name, " is assigned to device ",
device.name(), " with type ", cp->group.device_type.type_string(),
" and group_key ", cp->group.group_key, " but that group has type ",
gr->group.device_type.type_string());
} else if (cp->group.group_size != gr->group.group_size) {
@ -157,38 +167,47 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
}
if (gr->status.ok()) {
// Insert device if not already present.
auto it = gr->device_set.find(device);
if (it == gr->device_set.end()) {
if (gr->device_set.size() == gr->group.group_size) {
auto it = gr->devices.find(device.name());
if (it == gr->devices.end()) {
if (gr->devices.size() == gr->group.group_size) {
// The group is already full.
gr->status = errors::Internal(
"Collective Op ", cp->name, " is assigned to device ", device,
" and group_key ", cp->group.group_key,
"Collective Op ", cp->name, " is assigned to device ",
device.name(), " and group_key ", cp->group.group_key,
" but that group doesn't contain that device.");
} else {
// This is a new device that has not yet joined the group.
gr->device_set.insert(device);
gr->device_list.push_back(device);
DeviceNameUtils::ParsedName parsed_device;
DeviceNameUtils::ParseFullName(device, &parsed_device);
string task_name = strings::StrCat("/job:", parsed_device.job,
"/replica:", parsed_device.replica,
"/task:", parsed_device.task);
gr->task_set.insert(task_name);
gr->task_list.push_back(task_name);
gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
gr->devices[device.name()] = device;
if (gr->devices.size() == gr->group.group_size) {
// The group is full after adding this device, calculate the number
// of tasks.
std::unordered_set<string> tasks;
for (const auto& item : gr->devices) {
tasks.insert(TaskNameFromDeviceName(item.first));
}
gr->group.num_tasks = static_cast<int32>(tasks.size());
}
if (VLOG_IS_ON(1)) {
string dev_buf;
for (const auto& d : gr->device_set) {
strings::StrAppend(&dev_buf, ",", d);
for (const auto& d : gr->devices) {
strings::StrAppend(&dev_buf, ",", d.first);
}
VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
<< " group_size=" << gr->group.group_size << " (current"
<< " devices)=(" << dev_buf << ") (number of"
<< " devices pending)="
<< (gr->group.group_size - gr->device_set.size());
<< (gr->group.group_size - gr->devices.size());
}
}
} else {
// If the device already exists, check if the incarnation matches.
if (it->second.incarnation() != device.incarnation()) {
gr->status = errors::FailedPrecondition(
"Device ", device.name(),
" current incarnation doesn't match with one in the group. This "
"usually means this worker has restarted but the collective "
"leader hasn't, or this worker connects to a wrong cluster.");
}
}
}
@ -196,13 +215,13 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
cp->group.runtime_details = gr->group.runtime_details;
// If the group is not yet complete, queue to wait for it.
VLOG(2) << "group_size " << gr->group.group_size << " set size "
<< gr->device_set.size() << " gr " << gr;
<< gr->devices.size() << " gr " << gr;
if (gr->device_set.size() < gr->group.group_size) {
if (gr->devices.size() < gr->group.group_size) {
gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
return;
}
CHECK_EQ(gr->device_set.size(), gr->group.group_size);
CHECK_EQ(gr->devices.size(), gr->group.group_size);
}
// At this point, we either have a full group, or an error status. Ensure
// that all callbacks are invoked with the appropriate status.
@ -481,10 +500,15 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
{
mutex_lock gl(gr->mu);
ir->shared.group = gr->group;
ir->shared.instance.device_names.assign(gr->device_list.begin(),
gr->device_list.end());
ir->shared.instance.task_names.assign(gr->task_list.begin(),
gr->task_list.end());
ir->shared.instance.device_names.clear();
ir->shared.instance.task_names.clear();
ir->shared.instance.device_names.reserve(gr->devices.size());
ir->shared.instance.task_names.reserve(gr->devices.size());
for (const auto& item : gr->devices) {
ir->shared.instance.device_names.push_back(item.first);
ir->shared.instance.task_names.push_back(
TaskNameFromDeviceName(item.first));
}
VLOG(2) << "Initialized names for instance: "
<< ir->shared.instance.ToString();
}
@ -682,15 +706,15 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
}
void CollectiveParamResolverLocal::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
const StatusCallback& done) {
VLOG(1) << "CompleteParams local " << device << " for " << cp << ": "
const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr, const StatusCallback& done) {
VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
<< cp->ToString();
CompleteGroupLocal(
device, cp,
[this, device, cp, done](const Status& s, const GroupRec* gr) {
if (s.ok()) {
CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
CompleteInstanceLocal(device.name(), gr, cp, cp->is_source, done);
} else {
done(s);
}

View File

@ -19,9 +19,11 @@ limitations under the License.
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/thread_annotations.h"
@ -45,7 +47,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
~CollectiveParamResolverLocal() override {}
void CompleteParamsAsync(const string& device, CollectiveParams* cp,
void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr,
const StatusCallback& done) override;
@ -70,10 +72,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
CollGroupParams group;
mutable mutex mu;
Status status TF_GUARDED_BY(mu);
std::set<string> device_set TF_GUARDED_BY(mu);
std::vector<string> device_list TF_GUARDED_BY(mu);
std::set<string> task_set TF_GUARDED_BY(mu);
std::vector<string> task_list TF_GUARDED_BY(mu);
std::unordered_map<string, DeviceAttributes> devices TF_GUARDED_BY(mu);
std::vector<StatusCallback> waiting TF_GUARDED_BY(mu);
};
@ -85,7 +84,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
// callback.
typedef std::function<void(const Status& s, const GroupRec* gr)>
GroupRecCallback;
void CompleteGroupLocal(const string& device, CollectiveParams* cp,
void CompleteGroupLocal(const DeviceAttributes& device, CollectiveParams* cp,
const GroupRecCallback& done)
TF_LOCKS_EXCLUDED(group_mu_);

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -86,6 +87,12 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
}
}
DeviceAttributes GetDeviceAttributes(const string& device_name) {
Device* device = nullptr;
TF_CHECK_OK(device_mgr_->LookupDevice(device_name, &device));
return device->attributes();
}
string task_name_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
@ -187,12 +194,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
cp->instance.impl_details.subdiv_offsets.push_back(0);
cp->is_source = false;
Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
prl_->CompleteParamsAsync(
GetDeviceAttributes(cp->instance.device_names[0]), cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
});
}
for (int i = 0; i < NUM_DEVS; ++i) {
@ -240,12 +248,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
CollectiveParams* cp = &cps[i];
InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
prl_->CompleteParamsAsync(
GetDeviceAttributes(cp->instance.device_names[0]), cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
});
}
for (int i = 0; i < NUM_DEVS; ++i) {
@ -278,12 +287,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
CollectiveParams* cp = &cps[i];
InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
prl_->CompleteParamsAsync(
GetDeviceAttributes(cp->instance.device_names[0]), cp,
nullptr /*CancellationManager*/,
[&statuses, &note, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
});
}
for (int i = 0; i < NUM_DEVS; ++i) {
@ -326,8 +336,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) {
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100,
/*is_source*/ i == 0);
prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
[&done](const Status& s) {
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
&cancel_mgr, [&done](const Status& s) {
EXPECT_EQ(s.code(), error::ABORTED);
EXPECT_EQ(s.error_message(), "__aborted__");
done.DecrementCount();
@ -355,8 +365,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
cp[i] = MakeCollectiveParams(group_key, instance_key,
/*is_source*/ i == 0);
prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
[&done](const Status& s) {
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
&cancel_mgr, [&done](const Status& s) {
EXPECT_EQ(s.code(), error::OK);
done.DecrementCount();
});
@ -373,12 +383,13 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
cp[i] = MakeCollectiveParams(group_key, instance_key + 1,
/*is_source*/ i == 0);
prl_->CompleteParamsAsync(
device, &cp[i], &cancel_mgr, [&done](const Status& s) {
EXPECT_EQ(s.code(), error::ABORTED);
EXPECT_EQ(s.error_message(), "__aborted__");
done.DecrementCount();
});
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
&cancel_mgr, [&done](const Status& s) {
EXPECT_EQ(s.code(), error::ABORTED);
EXPECT_EQ(s.error_message(),
"__aborted__");
done.DecrementCount();
});
start.DecrementCount();
});
}
@ -402,8 +413,8 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
cp[i] = MakeCollectiveParams(group_key, instance_key,
/*is_source*/ i == 0);
prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
[&done](const Status& s) {
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
&cancel_mgr, [&done](const Status& s) {
EXPECT_EQ(s.code(), error::OK);
done.DecrementCount();
});
@ -418,7 +429,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
Notification done;
auto cp = MakeCollectiveParams(group_key, instance_key,
/*is_source*/ true);
prl_->CompleteParamsAsync(device, &cp, &cancel_mgr,
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, &cancel_mgr,
[&done](const Status& s) {
EXPECT_EQ(s.code(), error::ABORTED);
EXPECT_EQ(s.error_message(), "__aborted__");
@ -457,7 +468,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) {
auto cp =
MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key,
/*is_source*/ i == 0);
prl_->CompleteParamsAsync(device, &cp, &cancel_mgr,
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp,
&cancel_mgr,
[&status, &n](const Status& s) {
status = s;
n.Notify();

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace tensorflow {
@ -35,7 +36,7 @@ class TestCollectiveExecutor : public CollectiveExecutor {
};
class TestParamResolver : public ParamResolverInterface {
void CompleteParamsAsync(const string& device, CollectiveParams* cp,
void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr,
const StatusCallback& done) override {
done(errors::Internal("Unimplemented"));

View File

@ -571,7 +571,9 @@ cc_library(
":device_resolver_distributed",
":worker_cache",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/strings",
],
)
@ -606,6 +608,7 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:collective_ops",
"@com_google_absl//absl/container:flat_hash_map",
],
)

View File

@ -18,14 +18,18 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace {
class CompleteGroupCall : public CancellableCall {
public:
CompleteGroupCall(const CollGroupParams& group, const string& device_name,
CompleteGroupCall(const CollGroupParams& group,
const DeviceAttributes& device,
const CollectiveType& collective_type,
CancellationManager* cancel_mgr,
const string& remote_worker, WorkerCacheInterface* wc)
@ -33,7 +37,7 @@ class CompleteGroupCall : public CancellableCall {
req_.set_group_key(group.group_key);
req_.set_group_size(group.group_size);
req_.set_device_type(group.device_type.type_string());
req_.add_device_name(device_name);
*req_.mutable_device_attributes() = device;
req_.set_collective_type(collective_type);
}
~CompleteGroupCall() override {}
@ -98,16 +102,16 @@ CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
}
void CollectiveParamResolverDistributed::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
const StatusCallback& done) {
VLOG(1) << "CompleteParams distributed " << device << " for " << cp << ": "
<< cp->ToString();
const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr, const StatusCallback& done) {
VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
<< ": " << cp->ToString();
CompleteGroupDistributed(device, cp, cancel_mgr,
[this, device, cp, cancel_mgr, done](
const Status& s, const GroupRec* gr) {
if (s.ok()) {
CompleteInstanceDistributed(device, gr, cp,
cancel_mgr, done);
CompleteInstanceDistributed(
device.name(), gr, cp, cancel_mgr, done);
} else {
done(s);
}
@ -117,28 +121,28 @@ void CollectiveParamResolverDistributed::CompleteParamsAsync(
void CollectiveParamResolverDistributed::CompleteGroupAsync(
const CompleteGroupRequest* request, CompleteGroupResponse* response,
CancellationManager* cancel_mgr, const StatusCallback& done) {
if (!request->has_device_attributes()) {
done(errors::Internal(
"CompleteGroupRequest device_attributes is not set. Make sure you're "
"running the same version of Tensorflow on all workers."));
return;
}
CollectiveParams cp;
cp.group.group_key = request->group_key();
cp.group.group_size = request->group_size();
cp.group.device_type = DeviceType(request->device_type());
for (const string& dn : request->device_name()) {
cp.instance.device_names.push_back(dn);
}
cp.instance.type = CollectiveType(request->collective_type());
CompleteGroupDistributed(
cp.instance.device_names[0], &cp, cancel_mgr,
request->device_attributes(), &cp, cancel_mgr,
[response, done](const Status& s, const GroupRec* gr) {
if (s.ok()) {
mutex_lock l(gr->mu);
response->set_group_key(gr->group.group_key);
response->set_group_size(gr->group.group_size);
response->set_device_type(gr->group.device_type.type_string());
response->set_num_tasks(gr->task_set.size());
for (const string& dn : gr->device_list) {
response->add_device_name(dn);
}
for (const string& tn : gr->task_list) {
response->add_task_name(tn);
response->set_num_tasks(gr->group.num_tasks);
for (const auto& item : gr->devices) {
*response->add_device_attributes() = item.second;
}
response->set_communicator_key(
gr->group.runtime_details.communicator_key);
@ -152,6 +156,22 @@ void CollectiveParamResolverDistributed::CompleteGroupAsync(
void CollectiveParamResolverDistributed::CompleteInstanceAsync(
const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
CancellationManager* cancel_mgr, const StatusCallback& done) {
GroupRec* gr = GetCachedGroup(request->group_key());
if (gr == nullptr) {
done(errors::FailedPrecondition(
"group ", request->group_key(),
" not found. This normally means the server has restarted"));
return;
}
{
mutex_lock l(gr->mu);
if (!gr->status.ok() || gr->devices.size() != gr->group.group_size) {
done(errors::FailedPrecondition(
"group ", request->group_key(),
" failed to resolve. This normally means the server has restarted"));
return;
}
}
CollectiveParams* cp = new CollectiveParams;
cp->name = request->name();
cp->group.group_key = request->group_key();
@ -164,56 +184,44 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync(
for (int32 offset : request->subdiv_offset()) {
cp->instance.impl_details.subdiv_offsets.push_back(offset);
}
string* device = new string(request->device());
VLOG(1) << "New cp " << cp << " for device " << *device << " : "
<< cp->ToString();
StatusCallback done_and_cleanup = [cp, device, done](const Status& s) {
StatusCallback done_and_cleanup = [cp, done](const Status& s) {
done(s);
delete cp;
delete device;
};
// Start by completing the group.
CompleteGroupDistributed(
*device, cp, cancel_mgr,
[this, cp, device, response, cancel_mgr, done_and_cleanup](
const Status& cg_status, const GroupRec* gr) {
if (cg_status.ok()) {
// Then complete the instance.
CompleteInstanceDistributed(
*device, gr, cp, cancel_mgr,
[this, gr, cp, response,
done_and_cleanup](const Status& ci_status) {
if (ci_status.ok()) {
// Now source_rank should be known, so
// retrieve it.
FindInstanceRec(
gr, cp,
[cp, response, done_and_cleanup](const Status& fi_status,
InstanceRec* ir) {
if (fi_status.ok()) {
mutex_lock l(ir->out_mu);
ir->WaitForOutMu(l);
response->set_instance_key(cp->instance.instance_key);
response->set_source_rank(ir->source_rank);
done_and_cleanup(fi_status);
} else {
done_and_cleanup(fi_status);
}
});
CompleteInstanceDistributed(
request->device(), gr, cp, cancel_mgr,
[this, gr, cp, response, done_and_cleanup](const Status& ci_status) {
if (ci_status.ok()) {
// Now source_rank should be known, so
// retrieve it.
FindInstanceRec(
gr, cp,
[cp, response, done_and_cleanup](const Status& fi_status,
InstanceRec* ir) {
if (fi_status.ok()) {
mutex_lock l(ir->out_mu);
ir->WaitForOutMu(l);
response->set_instance_key(cp->instance.instance_key);
response->set_source_rank(ir->source_rank);
done_and_cleanup(fi_status);
} else {
done_and_cleanup(ci_status);
done_and_cleanup(fi_status);
}
});
} else {
done_and_cleanup(cg_status);
done_and_cleanup(ci_status);
}
});
}
bool CollectiveParamResolverDistributed::GroupIsCached(int32 group_key) {
CollectiveParamResolverDistributed::GroupRec*
CollectiveParamResolverDistributed::GetCachedGroup(int32 group_key) {
mutex_lock l(group_mu_);
const auto& it = group_table_.find(group_key);
return it != group_table_.end();
auto it = group_table_.find(group_key);
if (it == group_table_.end()) {
return nullptr;
}
return it->second.get();
}
Status CollectiveParamResolverDistributed::UpdateGroupCache(
@ -226,26 +234,19 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache(
gr->group.group_key = resp.group_key();
gr->group.group_size = resp.group_size();
gr->group.num_tasks = resp.num_tasks();
if (resp.device_name_size() != gr->group.group_size) {
if (resp.device_attributes().empty()) {
return errors::Internal(
"CompleteGroupResponse device_attributes is empty. Make sure you're "
"running the same version of Tensorflow on all workers.");
}
if (resp.device_attributes_size() != gr->group.group_size) {
return errors::Internal(
"CompleteGroupResponse group_size doesn't match device_name list");
}
for (const string& dn : resp.device_name()) {
gr->device_set.insert(dn);
gr->device_list.push_back(dn);
for (const DeviceAttributes& device : resp.device_attributes()) {
gr->devices[device.name()] = device;
}
if (resp.task_name_size() != gr->group.group_size) {
return errors::Internal(
"CompleteGroupResponse group_size doesn't match task_name list");
}
for (const string& tn : resp.task_name()) {
gr->task_list.push_back(tn);
gr->task_set.insert(tn);
}
CHECK_EQ(gr->task_set.size(), gr->group.num_tasks);
gr->group.runtime_details.communicator_key = resp.communicator_key();
VLOG(2) << "Group communicator_key="
<< absl::CEscape(gr->group.runtime_details.communicator_key);
}
{
// Group membership should never change. Once a record is in group_table_
@ -273,14 +274,15 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache(
}
void CollectiveParamResolverDistributed::CompleteGroupDistributed(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
const GroupRecCallback& done) {
const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr, const GroupRecCallback& done) {
VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
<< " dev: " << device << " is_leader=" << (group_leader_.empty());
<< " dev: " << device.name()
<< " is_leader=" << (group_leader_.empty());
if (group_leader_.empty()) {
// This is the group leader, so resolution is local.
return CompleteGroupLocal(device, cp, done);
} else if (!GroupIsCached(cp->group.group_key)) {
} else if (GetCachedGroup(cp->group.group_key) == nullptr) {
// Need to update Group cache from the leader.
CompleteGroupCall* call =
new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
class ConfigProto;
@ -31,7 +32,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
WorkerCacheInterface* worker_cache,
const string& task_name);
void CompleteParamsAsync(const string& device, CollectiveParams* cp,
void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr,
const StatusCallback& done) override;
@ -46,9 +47,9 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
const StatusCallback& done) override;
protected:
// Returns true iff there's an entry for this group_key in the
// local group_table_.
bool GroupIsCached(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
// Returns the cached group iff there's an entry for this group_key in the
// local group_table_; returns nullptr otherwise.
GroupRec* GetCachedGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
// Updates group_table_ with contents of resp.
Status UpdateGroupCache(const CompleteGroupResponse& resp)
@ -59,7 +60,8 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
//
// Semantics are like those of CompleteGroupLocal but will make a
// remote call to the group leader if necessary.
void CompleteGroupDistributed(const string& device, CollectiveParams* cp,
void CompleteGroupDistributed(const DeviceAttributes& device,
CollectiveParams* cp,
CancellationManager* cancel_mgr,
const GroupRecCallback& done);

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
@ -23,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/device_name_utils.h"
@ -41,6 +43,7 @@ static std::unique_ptr<Device> NewDevice(const string& type,
attr.set_name(name);
attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(3); // a non-default value
attr.set_incarnation(random::New64());
return absl::make_unique<FakeDevice>(attr);
}
@ -125,127 +128,110 @@ class FakeCache : public TestWorkerCache {
class DeviceResDistTest : public ::testing::Test {
protected:
DeviceResDistTest() {}
~DeviceResDistTest() override {
for (DeviceMgr* dm : device_mgrs_) {
delete dm;
}
for (auto it : dev_resolvers_) {
delete it.second;
}
for (auto it : cp_resolvers_) {
delete it.second;
}
for (FakeWorker* w : workers_) {
delete w;
}
}
void DefineWorkers(int num_workers, int num_devices,
const string& device_type, bool nccl) {
ConfigProto config;
for (int w = 0; w < num_workers; ++w) {
string name = strings::StrCat("/job:worker/replica:0/task:", w);
if (w == 0) {
config.mutable_experimental()->set_collective_group_leader(name);
if (nccl) {
config.mutable_experimental()->set_collective_nccl(true);
}
}
DefineWorker(config, name, device_type, num_devices);
DefineWorker(name, device_type, num_devices, nccl);
}
}
void DefineWorker(const ConfigProto& config, const string& worker_name,
const string& device_type, int num_devices) {
void DefineWorker(const string& worker_name, const string& device_type,
int num_devices, bool nccl) {
ConfigProto config;
config.mutable_experimental()->set_collective_group_leader(
"/job:worker/replica:0/task:0");
config.mutable_experimental()->set_collective_nccl(nccl);
std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice(
device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
}
DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr);
device_mgrs_[worker_name] =
absl::make_unique<StaticDeviceMgr>(std::move(devices));
std::vector<string>* dv = &dev_by_task_[worker_name];
for (auto* d : dev_mgr->ListDevices()) {
dv->clear();
for (auto* d : device_mgrs_[worker_name]->ListDevices()) {
dv->push_back(d->name());
}
DeviceResolverDistributed* dev_res =
new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
dev_resolvers_[worker_name] = dev_res;
CollectiveParamResolverDistributed* cp_res =
new CollectiveParamResolverDistributed(config, dev_mgr, dev_res, &wc_,
worker_name);
cp_resolvers_[worker_name] = cp_res;
FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, cp_res);
workers_.push_back(fw);
wc_.AddWorker(worker_name, fw);
dev_resolvers_[worker_name] = absl::make_unique<DeviceResolverDistributed>(
device_mgrs_[worker_name].get(), &wc_, worker_name);
cp_resolvers_[worker_name] =
absl::make_unique<CollectiveParamResolverDistributed>(
config, device_mgrs_[worker_name].get(),
dev_resolvers_[worker_name].get(), &wc_, worker_name);
workers_[worker_name] = absl::make_unique<FakeWorker>(
worker_name, device_mgrs_[worker_name].get(),
cp_resolvers_[worker_name].get());
wc_.AddWorker(worker_name, workers_[worker_name].get());
}
void DefineCollectiveParams(int num_workers, int num_devices) {
const int kGroupKey = 5;
const int kInstanceKey = 3;
void DefineCollectiveParams(int num_workers, int num_devices,
const string& device_type) {
for (int wi = 0; wi < num_workers; ++wi) {
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
for (int di = 0; di < num_devices; ++di) {
string device_name = strings::StrCat(task_name, "/device:CPU:", di);
cp_.push_back(CollectiveParams());
CollectiveParams& cp = cp_.back();
cp.group.group_key = kGroupKey;
cp.group.group_size = num_workers * num_devices;
cp.group.device_type = DEVICE_CPU;
cp.group.num_tasks = num_workers;
cp.instance.instance_key = kInstanceKey;
cp.instance.type = REDUCTION_COLLECTIVE;
cp.instance.data_type = DT_FLOAT;
cp.instance.shape = TensorShape({64});
cp.instance.impl_details.subdiv_offsets.push_back(0);
string device_name =
strings::StrCat(task_name, "/device:", device_type, ":", di);
cp_[device_name] =
CreateCollectiveParams(num_workers, num_devices, device_type);
}
}
}
CollectiveParams CreateCollectiveParams(int num_workers, int num_devices,
const string& device_type) {
const int kGroupKey = 5;
const int kInstanceKey = 3;
CollectiveParams cp;
cp.group.group_key = kGroupKey;
cp.group.group_size = num_workers * num_devices;
cp.group.device_type = DeviceType(device_type);
cp.group.num_tasks = num_workers;
cp.instance.instance_key = kInstanceKey;
cp.instance.type = REDUCTION_COLLECTIVE;
cp.instance.data_type = DT_FLOAT;
cp.instance.shape = TensorShape({64});
cp.instance.impl_details.subdiv_offsets.push_back(0);
return cp;
}
void IssueRequests(int num_workers, int num_devices) {
const int device_count = num_workers * num_devices;
{
mutex_lock l(mu_);
num_done_ = 0;
}
cp_.resize(device_count);
status_.resize(device_count);
int idx = 0;
int group_size = num_workers * num_devices;
for (int wi = 0; wi < num_workers; ++wi) {
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
for (int di = 0; di < num_devices; ++di) {
IssueRequest(num_workers, num_devices, idx);
++idx;
string device_name = strings::StrCat(task_name, "/device:CPU:", di);
IssueRequest(task_name, device_name, group_size);
}
}
}
void IssueRequest(int num_workers, int num_devices, int idx) {
int device_count = num_workers * num_devices;
int wi = idx / num_devices;
int di = idx % num_devices;
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
string device_name = strings::StrCat(task_name, "/device:CPU:", di);
while (idx >= cp_.size()) {
status_.resize(idx + 1);
cp_.resize(idx + 1);
}
CollectiveParams* cp = &cp_[idx];
CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name];
void IssueRequest(const string& task_name, const string& device_name,
int group_size) {
Device* device = nullptr;
TF_CHECK_OK(device_mgrs_[task_name]->LookupDevice(device_name, &device));
CollectiveParams* cp = &cp_[device_name];
CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name].get();
CHECK(cp_res);
cp_res->CompleteParamsAsync(device_name, cp, &cm_,
[this, idx, device_count](const Status& s) {
status_[idx] = s;
{
mutex_lock l(mu_);
++num_done_;
if (num_done_ == device_count) {
done_.notify_all();
}
}
});
cp_res->CompleteParamsAsync(
device->attributes(), cp, &cm_,
[this, device_name, group_size](const Status& s) {
status_[device_name] = s;
{
mutex_lock l(mu_);
++num_done_;
if (num_done_ == group_size) {
done_.notify_all();
}
}
});
}
void ValidateCollectiveParams(int num_workers, int num_devices) {
@ -259,39 +245,59 @@ class DeviceResDistTest : public ::testing::Test {
// Verify that all cp_ values get the same set of task and device
// names, with unique default_rank in the expected order.
const int dev_count = num_workers * num_devices;
string dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
for (int wi = 0; wi < num_workers; ++wi) {
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
for (int di = 0; di < num_devices; ++di) {
string device_name = strings::StrCat(task_name, "/device:CPU:", di);
int idx = wi * num_devices + di;
TF_ASSERT_OK(status_[idx]);
EXPECT_EQ(cp_[idx].default_rank, idx);
EXPECT_EQ(cp_[idx].instance.device_names.size(), dev_count);
EXPECT_EQ(cp_[idx].instance.device_names[idx], device_name);
EXPECT_EQ(cp_[idx].instance.task_names[idx], task_name);
TF_ASSERT_OK(status_[device_name]);
EXPECT_EQ(cp_[device_name].default_rank, idx);
EXPECT_EQ(cp_[device_name].instance.device_names.size(), dev_count);
EXPECT_EQ(cp_[device_name].instance.device_names[idx], device_name);
EXPECT_EQ(cp_[device_name].instance.task_names[idx], task_name);
if (idx > 0) {
EXPECT_EQ(cp_[0].group.runtime_details.communicator_key,
cp_[idx].group.runtime_details.communicator_key);
EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key,
cp_[device_name].group.runtime_details.communicator_key);
for (int i = 0; i < dev_count; ++i) {
EXPECT_EQ(cp_[0].instance.device_names[i],
cp_[idx].instance.device_names[i]);
EXPECT_EQ(cp_[0].instance.task_names[i],
cp_[idx].instance.task_names[i]);
EXPECT_EQ(cp_[dev0].instance.device_names[i],
cp_[device_name].instance.device_names[i]);
EXPECT_EQ(cp_[dev0].instance.task_names[i],
cp_[device_name].instance.task_names[i]);
}
}
}
}
}
void RestartWorker(int worker_idx, int num_workers, int num_devices,
const string& device_type, bool nccl) {
string worker_name =
strings::StrCat("/job:worker/replica:0/task:", worker_idx);
DefineWorker(worker_name, device_type, num_devices, nccl);
for (int i = 0; i < num_devices; ++i) {
string device_name =
strings::StrCat(worker_name, "/device:", device_type, ":", i);
cp_[device_name] =
CreateCollectiveParams(num_workers, num_devices, device_type);
status_.erase(device_name);
}
}
FakeCache wc_;
CancellationManager cm_;
std::vector<DeviceMgr*> device_mgrs_;
std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
std::unordered_map<string, CollectiveParamResolverDistributed*> cp_resolvers_;
std::unordered_map<string, std::vector<string>> dev_by_task_;
std::vector<FakeWorker*> workers_;
std::vector<CollectiveParams> cp_;
std::vector<Status> status_;
// Below are keyed by task names.
absl::flat_hash_map<string, std::unique_ptr<DeviceMgr>> device_mgrs_;
absl::flat_hash_map<string, std::unique_ptr<DeviceResolverDistributed>>
dev_resolvers_;
absl::flat_hash_map<string,
std::unique_ptr<CollectiveParamResolverDistributed>>
cp_resolvers_;
absl::flat_hash_map<string, std::vector<string>> dev_by_task_;
absl::flat_hash_map<string, std::unique_ptr<FakeWorker>> workers_;
// Below are keyed by device names;
absl::flat_hash_map<string, CollectiveParams> cp_;
absl::flat_hash_map<string, Status> status_;
mutex mu_;
int num_done_ TF_GUARDED_BY(mu_);
condition_variable done_;
@ -300,8 +306,8 @@ class DeviceResDistTest : public ::testing::Test {
TEST_F(DeviceResDistTest, Workers1Devices1) {
const int num_workers = 1;
const int num_devices = 1;
DefineWorkers(num_workers, num_devices, "CPU", false);
DefineCollectiveParams(num_workers, num_devices);
DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
DefineCollectiveParams(num_workers, num_devices, "CPU");
IssueRequests(num_workers, num_devices);
ValidateCollectiveParams(num_workers, num_devices);
}
@ -309,12 +315,25 @@ TEST_F(DeviceResDistTest, Workers1Devices1) {
TEST_F(DeviceResDistTest, Workers2Devices2) {
const int num_workers = 2;
const int num_devices = 2;
DefineWorkers(num_workers, num_devices, "CPU", false);
DefineCollectiveParams(num_workers, num_devices);
DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
DefineCollectiveParams(num_workers, num_devices, "CPU");
IssueRequests(num_workers, num_devices);
ValidateCollectiveParams(num_workers, num_devices);
}
TEST_F(DeviceResDistTest, DifferentIncarnation) {
const int num_workers = 2;
const int num_devices = 1;
DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
DefineCollectiveParams(num_workers, num_devices, "CPU");
IssueRequests(num_workers, num_devices);
RestartWorker(1, num_workers, num_devices, "CPU", /*nccl*/ false);
const string task_name = "/job:worker/replica:0/task:1";
const string device_name = absl::StrCat(task_name, "/device:CPU:0");
IssueRequest(task_name, device_name, num_workers * num_devices);
EXPECT_TRUE(errors::IsFailedPrecondition(status_[device_name]));
}
#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
namespace {
// A mock NcclReducer for testing group runtime details initialization with CPU
@ -347,7 +366,7 @@ TEST_F(DeviceResDistTest, Workers4Devices3) {
const int num_workers = 4;
const int num_devices = 3;
DefineWorkers(num_workers, num_devices, "CPU", true);
DefineCollectiveParams(num_workers, num_devices);
DefineCollectiveParams(num_workers, num_devices, "CPU");
IssueRequests(num_workers, num_devices);
ValidateCollectiveParams(num_workers, num_devices);
}

View File

@ -180,7 +180,8 @@ class ParamResolverInterface {
// Called by each collective op at first execution in order to fill out
// the CollectiveParams structure with data gathered from the full
// (maybe distributed) collection of peer nodes.
virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
virtual void CompleteParamsAsync(const DeviceAttributes& device,
CollectiveParams* cp,
CancellationManager* cancel_mgr,
const StatusCallback& done) = 0;
@ -301,7 +302,8 @@ class CollectiveExecutor : public core::RefCounted {
"a CollectiveExecutor has not been provided."));
}
virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
virtual void CompleteParamsAsync(const DeviceAttributes& device,
CollectiveParams* cp,
CancellationManager* cancel_mgr,
StatusCallback done) {
done(errors::Internal(

View File

@ -73,7 +73,7 @@ class CollectiveOpKernel : public AsyncOpKernel {
<< " group " << col_params_.group.group_key << " instance "
<< col_params_.instance.instance_key;
col_exec->CompleteParamsAsync(
c->device()->name(), &col_params_, c->cancellation_manager(),
c->device()->attributes(), &col_params_, c->cancellation_manager(),
[this, c, done](const Status& s) {
if (s.ok()) {
col_params_.instance.impl_details.dependencies = dependencies_;
@ -538,7 +538,8 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
<< " group " << col_params->group.group_key << " instance "
<< col_params->instance.instance_key;
col_exec->CompleteParamsAsync(
c->device()->name(), col_params.get(), c->cancellation_manager(),
c->device()->attributes(), col_params.get(),
c->cancellation_manager(),
[c, done = std::move(done), col_params, col_exec](const Status& s) {
if (s.ok()) {
auto actual_done = [c, group_key = col_params->group.group_key,

View File

@ -545,8 +545,10 @@ message CompleteGroupRequest {
int32 group_key = 1;
int32 group_size = 2;
string device_type = 3;
repeated string device_name = 4;
int32 collective_type = 5;
DeviceAttributes device_attributes = 6;
reserved 4;
}
// Gives the complete membership of the group identified by group_key.
@ -555,9 +557,10 @@ message CompleteGroupResponse {
int32 group_size = 2;
string device_type = 3;
int32 num_tasks = 4; // number of distinct tasks hosting the devices
repeated string device_name = 5;
repeated string task_name = 6; // task name prefixes of device_names
bytes communicator_key = 7;
repeated DeviceAttributes device_attributes = 8;
reserved 5, 6;
}
// Supplies data about one collective op belonging to the instance identified