Exchange device attributes at group resolution again

Previously, CollectiveParamResolver queried device attributes when initializing instance params. That is problematic when the collective leader fails and restarts quickly between group resolution and instance resolution: in that case, all other workers get the incarnation of the restarted leader and are thus unable to detect that the leader has failed, while the restarted leader deadlocks on group resolution.
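
The mechanism, in brief: every device reports an incarnation, a random 64-bit value chosen at startup, so two boots of the same worker are distinguishable even though the device name is identical. A minimal sketch of the idea, using names from the diffs below (TensorFlow context assumed, surrounding plumbing elided):

    #include "tensorflow/core/framework/device_attributes.pb.h"
    #include "tensorflow/core/platform/random.h"

    // Each device picks a fresh random incarnation when its worker starts,
    // so a restarted worker presents the same name with a new incarnation.
    tensorflow::DeviceAttributes MakeDeviceAttributes(const std::string& name) {
      tensorflow::DeviceAttributes attr;
      attr.set_name(name);
      attr.set_incarnation(tensorflow::random::New64());
      return attr;
    }

    // Group resolution can then detect a restart by comparing the cached
    // attributes against the ones in the incoming request.
    bool SameBoot(const tensorflow::DeviceAttributes& cached,
                  const tensorflow::DeviceAttributes& current) {
      return cached.incarnation() == current.incarnation();
    }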

This change doesn't fully fix the issue, because it only exchanges device attributes at group resolution and doesn't yet populate them into DeviceResolver. That will be done in a follow-up change.

This change also alters the behavior when a non-leader worker fails and restarts. Previously it got the cached group resolution from the leader; now it gets an error, because its incarnation doesn't match the one in the cached group parameters. This should have no practical effect, since that worker will always restart again after the leader has restarted.
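
The check that produces that error lives in CompleteGroupLocal (see the collective_param_resolver_local.cc diff below). Condensed into a predicate, assuming the GroupRec type as declared in collective_param_resolver_local.h:

    // Condensed from the CompleteGroupLocal change below: true iff the group
    // record already holds this device name under a different incarnation,
    // i.e. the worker restarted after joining the group.
    bool IsRestartedMember(const GroupRec* gr,
                           const tensorflow::DeviceAttributes& device) {
      auto it = gr->devices.find(device.name());
      return it != gr->devices.end() &&
             it->second.incarnation() != device.incarnation();
    }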

This change modifies both the client and the server in a way that is not backward compatible; it assumes the client and server run the same version of TensorFlow. This should hold in practice, since the only way to use CollectiveParamResolverDistributed is through MultiWorkerMirroredStrategy (MWMS), and with MWMS all workers are expected to run the same version of the program.
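
The incompatibility is guarded explicitly on the server side: an old client fills the removed device_name field instead of device_attributes, and the new server rejects such requests up front rather than failing obscurely later (condensed from the CompleteGroupAsync hunk below):

    if (!request->has_device_attributes()) {
      done(errors::Internal(
          "CompleteGroupRequest device_attributes is not set. Make sure "
          "you're running the same version of Tensorflow on all workers."));
      return;
    }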

PiperOrigin-RevId: 329774332
Change-Id: I6f3ba535cefd3a8ec321e4138436b6a085d64463
Ran Chen, 2020-09-02 13:16:06 -07:00, committed by TensorFlower Gardener
parent edb454c157
commit 9da4a05369
14 changed files with 340 additions and 273 deletions

tensorflow/core/common_runtime/BUILD

@@ -379,7 +379,6 @@ cc_library(
     hdrs = ["collective_param_resolver_local.h"],
     copts = tf_copts(),
     deps = [
-        ":device",
         ":device_mgr",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",

tensorflow/core/common_runtime/base_collective_executor.cc

@@ -298,8 +298,8 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
 }
 
 void BaseCollectiveExecutor::CompleteParamsAsync(
-    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
-    StatusCallback done) {
+    const DeviceAttributes& device, CollectiveParams* cp,
+    CancellationManager* cancel_mgr, StatusCallback done) {
   cp->instance.gpu_ring_order = *gpu_ring_order_;
   const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
   auto done_with_timeout = done;

tensorflow/core/common_runtime/base_collective_executor.h

@@ -113,7 +113,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
   void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
                     const string& exec_key, StatusCallback done) override;
 
-  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
                            CancellationManager* cancel_mgr,
                            StatusCallback done) override;

tensorflow/core/common_runtime/collective_param_resolver_local.cc

@@ -17,7 +17,7 @@ limitations under the License.
 #include <stddef.h>
 
 #include <algorithm>
-#include <unordered_map>
+#include <unordered_set>
 #include <utility>
 
 #include "tensorflow/core/common_runtime/device_mgr.h"
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/config.pb.h"
@@ -74,12 +75,21 @@ const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
       return "undef";
   }
 }
 
+string TaskNameFromDeviceName(const string& device_name) {
+  DeviceNameUtils::ParsedName parsed_device;
+  CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
+  string task_name;
+  CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
+  return task_name;
+}
+
 }  // namespace
 
 void CollectiveParamResolverLocal::CompleteGroupLocal(
-    const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
-  VLOG(1) << "CompleteGroupLocal device=" << device << " cp: " << cp << ": "
-          << cp->ToString();
+    const DeviceAttributes& device, CollectiveParams* cp,
+    const GroupRecCallback& done) {
+  VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp
+          << ": " << cp->ToString();
   std::vector<StatusCallback> to_be_called;
   GroupRec* gr = nullptr;
   Status status;
@@ -139,13 +149,13 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
     // status.
     VLOG(2) << "gr device_type=" << gr->group.device_type
             << " cp device_type=" << cp->group.device_type
-            << " current device=" << device;
+            << " current device=" << device.name();
     if (gr->status.ok()) {
       // Check for consistency with existing GroupRec.
       if (cp->group.device_type != gr->group.device_type) {
         gr->status = errors::Internal(
-            "Collective Op ", cp->name, " is assigned to device ", device,
-            " with type ", cp->group.device_type.type_string(),
+            "Collective Op ", cp->name, " is assigned to device ",
+            device.name(), " with type ", cp->group.device_type.type_string(),
             " and group_key ", cp->group.group_key, " but that group has type ",
             gr->group.device_type.type_string());
       } else if (cp->group.group_size != gr->group.group_size) {
@@ -157,38 +167,47 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
     }
     if (gr->status.ok()) {
       // Insert device if not already present.
-      auto it = gr->device_set.find(device);
-      if (it == gr->device_set.end()) {
-        if (gr->device_set.size() == gr->group.group_size) {
+      auto it = gr->devices.find(device.name());
+      if (it == gr->devices.end()) {
+        if (gr->devices.size() == gr->group.group_size) {
           // The group is already full.
           gr->status = errors::Internal(
-              "Collective Op ", cp->name, " is assigned to device ", device,
-              " and group_key ", cp->group.group_key,
+              "Collective Op ", cp->name, " is assigned to device ",
+              device.name(), " and group_key ", cp->group.group_key,
              " but that group doesn't contain that device.");
         } else {
           // This is a new device that has not yet joined the group.
-          gr->device_set.insert(device);
-          gr->device_list.push_back(device);
-          DeviceNameUtils::ParsedName parsed_device;
-          DeviceNameUtils::ParseFullName(device, &parsed_device);
-          string task_name = strings::StrCat("/job:", parsed_device.job,
-                                             "/replica:", parsed_device.replica,
-                                             "/task:", parsed_device.task);
-          gr->task_set.insert(task_name);
-          gr->task_list.push_back(task_name);
-          gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
+          gr->devices[device.name()] = device;
+          if (gr->devices.size() == gr->group.group_size) {
+            // The group is full after adding this device, calculate the number
+            // of tasks.
+            std::unordered_set<string> tasks;
+            for (const auto& item : gr->devices) {
+              tasks.insert(TaskNameFromDeviceName(item.first));
+            }
+            gr->group.num_tasks = static_cast<int32>(tasks.size());
+          }
           if (VLOG_IS_ON(1)) {
             string dev_buf;
-            for (const auto& d : gr->device_set) {
-              strings::StrAppend(&dev_buf, ",", d);
+            for (const auto& d : gr->devices) {
+              strings::StrAppend(&dev_buf, ",", d.first);
             }
             VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
                     << " group_size=" << gr->group.group_size << " (current"
                     << " devices)=(" << dev_buf << ") (number of"
                     << " devices pending)="
-                    << (gr->group.group_size - gr->device_set.size());
+                    << (gr->group.group_size - gr->devices.size());
           }
         }
+      } else {
+        // If the device already exists, check if the incarnation matches.
+        if (it->second.incarnation() != device.incarnation()) {
+          gr->status = errors::FailedPrecondition(
+              "Device ", device.name(),
+              " current incarnation doesn't match with one in the group. This "
+              "usually means this worker has restarted but the collective "
+              "leader hasn't, or this worker connects to a wrong cluster.");
+        }
       }
     }
   }
@@ -196,13 +215,13 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
     cp->group.runtime_details = gr->group.runtime_details;
     // If the group is not yet complete, queue to wait for it.
     VLOG(2) << "group_size " << gr->group.group_size << " set size "
-            << gr->device_set.size() << " gr " << gr;
-    if (gr->device_set.size() < gr->group.group_size) {
+            << gr->devices.size() << " gr " << gr;
+    if (gr->devices.size() < gr->group.group_size) {
       gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
       return;
     }
-    CHECK_EQ(gr->device_set.size(), gr->group.group_size);
+    CHECK_EQ(gr->devices.size(), gr->group.group_size);
   }
   // At this point, we either have a full group, or an error status. Ensure
   // that all callbacks are invoked with the appropriate status.
@@ -481,10 +500,15 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
   {
     mutex_lock gl(gr->mu);
     ir->shared.group = gr->group;
-    ir->shared.instance.device_names.assign(gr->device_list.begin(),
-                                            gr->device_list.end());
-    ir->shared.instance.task_names.assign(gr->task_list.begin(),
-                                          gr->task_list.end());
+    ir->shared.instance.device_names.clear();
+    ir->shared.instance.task_names.clear();
+    ir->shared.instance.device_names.reserve(gr->devices.size());
+    ir->shared.instance.task_names.reserve(gr->devices.size());
+    for (const auto& item : gr->devices) {
+      ir->shared.instance.device_names.push_back(item.first);
+      ir->shared.instance.task_names.push_back(
+          TaskNameFromDeviceName(item.first));
+    }
     VLOG(2) << "Initialized names for instance: "
             << ir->shared.instance.ToString();
   }
@@ -682,15 +706,15 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
 }
 
 void CollectiveParamResolverLocal::CompleteParamsAsync(
-    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
-    const StatusCallback& done) {
-  VLOG(1) << "CompleteParams local " << device << " for " << cp << ": "
-          << cp->ToString();
+    const DeviceAttributes& device, CollectiveParams* cp,
+    CancellationManager* cancel_mgr, const StatusCallback& done) {
+  VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
+          << cp->ToString();
   CompleteGroupLocal(
       device, cp,
       [this, device, cp, done](const Status& s, const GroupRec* gr) {
         if (s.ok()) {
-          CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
+          CompleteInstanceLocal(device.name(), gr, cp, cp->is_source, done);
         } else {
           done(s);
         }

tensorflow/core/common_runtime/collective_param_resolver_local.h

@@ -19,9 +19,11 @@ limitations under the License.
 #include <memory>
 #include <set>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/platform/thread_annotations.h"
@@ -45,7 +47,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
   ~CollectiveParamResolverLocal() override {}
 
-  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
                            CancellationManager* cancel_mgr,
                            const StatusCallback& done) override;
@@ -70,10 +72,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
     CollGroupParams group;
     mutable mutex mu;
     Status status TF_GUARDED_BY(mu);
-    std::set<string> device_set TF_GUARDED_BY(mu);
-    std::vector<string> device_list TF_GUARDED_BY(mu);
-    std::set<string> task_set TF_GUARDED_BY(mu);
-    std::vector<string> task_list TF_GUARDED_BY(mu);
+    std::unordered_map<string, DeviceAttributes> devices TF_GUARDED_BY(mu);
     std::vector<StatusCallback> waiting TF_GUARDED_BY(mu);
   };
@@ -85,7 +84,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
   // callback.
   typedef std::function<void(const Status& s, const GroupRec* gr)>
       GroupRecCallback;
-  void CompleteGroupLocal(const string& device, CollectiveParams* cp,
+  void CompleteGroupLocal(const DeviceAttributes& device, CollectiveParams* cp,
                           const GroupRecCallback& done)
       TF_LOCKS_EXCLUDED(group_mu_);

tensorflow/core/common_runtime/collective_param_resolver_local_test.cc

@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_resolver_local.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -86,6 +87,12 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
     }
   }
 
+  DeviceAttributes GetDeviceAttributes(const string& device_name) {
+    Device* device = nullptr;
+    TF_CHECK_OK(device_mgr_->LookupDevice(device_name, &device));
+    return device->attributes();
+  }
+
   string task_name_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<DeviceResolverLocal> drl_;
@@ -187,12 +194,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
     cp->instance.impl_details.subdiv_offsets.push_back(0);
     cp->is_source = false;
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
-      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
-                                nullptr /*CancellationManager*/,
-                                [&statuses, &note, i](const Status& s) {
-                                  statuses[i] = s;
-                                  note[i].Notify();
-                                });
+      prl_->CompleteParamsAsync(
+          GetDeviceAttributes(cp->instance.device_names[0]), cp,
+          nullptr /*CancellationManager*/,
+          [&statuses, &note, i](const Status& s) {
+            statuses[i] = s;
+            note[i].Notify();
+          });
     });
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
@@ -240,12 +248,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
    CollectiveParams* cp = &cps[i];
    InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
    Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
-      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
-                                nullptr /*CancellationManager*/,
-                                [&statuses, &note, i](const Status& s) {
-                                  statuses[i] = s;
-                                  note[i].Notify();
-                                });
+      prl_->CompleteParamsAsync(
+          GetDeviceAttributes(cp->instance.device_names[0]), cp,
+          nullptr /*CancellationManager*/,
+          [&statuses, &note, i](const Status& s) {
+            statuses[i] = s;
+            note[i].Notify();
+          });
     });
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
@@ -278,12 +287,13 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
    CollectiveParams* cp = &cps[i];
    InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
    Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
-      prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
-                                nullptr /*CancellationManager*/,
-                                [&statuses, &note, i](const Status& s) {
-                                  statuses[i] = s;
-                                  note[i].Notify();
-                                });
+      prl_->CompleteParamsAsync(
+          GetDeviceAttributes(cp->instance.device_names[0]), cp,
+          nullptr /*CancellationManager*/,
+          [&statuses, &note, i](const Status& s) {
+            statuses[i] = s;
+            note[i].Notify();
+          });
     });
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
@@ -326,8 +336,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) {
           strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
       cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100,
                                    /*is_source*/ i == 0);
-      prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
-                                [&done](const Status& s) {
+      prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
+                                &cancel_mgr, [&done](const Status& s) {
                                   EXPECT_EQ(s.code(), error::ABORTED);
                                   EXPECT_EQ(s.error_message(), "__aborted__");
                                   done.DecrementCount();
@@ -355,8 +365,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
        strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
    cp[i] = MakeCollectiveParams(group_key, instance_key,
                                 /*is_source*/ i == 0);
-    prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
-                              [&done](const Status& s) {
+    prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
+                              &cancel_mgr, [&done](const Status& s) {
                                 EXPECT_EQ(s.code(), error::OK);
                                 done.DecrementCount();
                               });
@@ -373,12 +383,13 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
          strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
      cp[i] = MakeCollectiveParams(group_key, instance_key + 1,
                                   /*is_source*/ i == 0);
-      prl_->CompleteParamsAsync(
-          device, &cp[i], &cancel_mgr, [&done](const Status& s) {
-            EXPECT_EQ(s.code(), error::ABORTED);
-            EXPECT_EQ(s.error_message(), "__aborted__");
-            done.DecrementCount();
-          });
+      prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
+                                &cancel_mgr, [&done](const Status& s) {
+                                  EXPECT_EQ(s.code(), error::ABORTED);
+                                  EXPECT_EQ(s.error_message(),
+                                            "__aborted__");
+                                  done.DecrementCount();
+                                });
       start.DecrementCount();
     });
   }
@@ -402,8 +413,8 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
        strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
    cp[i] = MakeCollectiveParams(group_key, instance_key,
                                 /*is_source*/ i == 0);
-    prl_->CompleteParamsAsync(device, &cp[i], &cancel_mgr,
-                              [&done](const Status& s) {
+    prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
+                              &cancel_mgr, [&done](const Status& s) {
                                 EXPECT_EQ(s.code(), error::OK);
                                 done.DecrementCount();
                               });
@@ -418,7 +429,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
    Notification done;
    auto cp = MakeCollectiveParams(group_key, instance_key,
                                   /*is_source*/ true);
-    prl_->CompleteParamsAsync(device, &cp, &cancel_mgr,
+    prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, &cancel_mgr,
                              [&done](const Status& s) {
                                EXPECT_EQ(s.code(), error::ABORTED);
                                EXPECT_EQ(s.error_message(), "__aborted__");
@@ -457,7 +468,8 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) {
        auto cp =
            MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key,
                                 /*is_source*/ i == 0);
-        prl_->CompleteParamsAsync(device, &cp, &cancel_mgr,
+        prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp,
+                                  &cancel_mgr,
                                   [&status, &n](const Status& s) {
                                     status = s;
                                     n.Notify();

tensorflow/core/common_runtime/test_collective_executor_mgr.h

@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
 
 #include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 
 namespace tensorflow {
@@ -35,7 +36,7 @@ class TestCollectiveExecutor : public CollectiveExecutor {
 };
 
 class TestParamResolver : public ParamResolverInterface {
-  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
                            CancellationManager* cancel_mgr,
                            const StatusCallback& done) override {
     done(errors::Internal("Unimplemented"));

tensorflow/core/distributed_runtime/BUILD

@@ -571,7 +571,9 @@ cc_library(
         ":device_resolver_distributed",
         ":worker_cache",
         "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:errors",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -606,6 +608,7 @@ tf_cc_test(
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "//tensorflow/core/kernels:collective_ops",
+        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )

tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc

@@ -18,14 +18,18 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
 namespace {
 
 class CompleteGroupCall : public CancellableCall {
  public:
-  CompleteGroupCall(const CollGroupParams& group, const string& device_name,
+  CompleteGroupCall(const CollGroupParams& group,
+                    const DeviceAttributes& device,
                     const CollectiveType& collective_type,
                     CancellationManager* cancel_mgr,
                     const string& remote_worker, WorkerCacheInterface* wc)
@@ -33,7 +37,7 @@ class CompleteGroupCall : public CancellableCall {
     req_.set_group_key(group.group_key);
     req_.set_group_size(group.group_size);
     req_.set_device_type(group.device_type.type_string());
-    req_.add_device_name(device_name);
+    *req_.mutable_device_attributes() = device;
     req_.set_collective_type(collective_type);
   }
 
   ~CompleteGroupCall() override {}
@@ -98,16 +102,16 @@ CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
 }
 
 void CollectiveParamResolverDistributed::CompleteParamsAsync(
-    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
-    const StatusCallback& done) {
-  VLOG(1) << "CompleteParams distributed " << device << " for " << cp << ": "
-          << cp->ToString();
+    const DeviceAttributes& device, CollectiveParams* cp,
+    CancellationManager* cancel_mgr, const StatusCallback& done) {
+  VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
+          << ": " << cp->ToString();
   CompleteGroupDistributed(device, cp, cancel_mgr,
                            [this, device, cp, cancel_mgr, done](
                                const Status& s, const GroupRec* gr) {
                              if (s.ok()) {
-                               CompleteInstanceDistributed(device, gr, cp,
-                                                           cancel_mgr, done);
+                               CompleteInstanceDistributed(
+                                   device.name(), gr, cp, cancel_mgr, done);
                              } else {
                                done(s);
                              }
@@ -117,28 +121,28 @@ void CollectiveParamResolverDistributed::CompleteParamsAsync(
 void CollectiveParamResolverDistributed::CompleteGroupAsync(
     const CompleteGroupRequest* request, CompleteGroupResponse* response,
     CancellationManager* cancel_mgr, const StatusCallback& done) {
+  if (!request->has_device_attributes()) {
+    done(errors::Internal(
+        "CompleteGroupRequest device_attributes is not set. Make sure you're "
+        "running the same version of Tensorflow on all workers."));
+    return;
+  }
   CollectiveParams cp;
   cp.group.group_key = request->group_key();
   cp.group.group_size = request->group_size();
   cp.group.device_type = DeviceType(request->device_type());
-  for (const string& dn : request->device_name()) {
-    cp.instance.device_names.push_back(dn);
-  }
   cp.instance.type = CollectiveType(request->collective_type());
   CompleteGroupDistributed(
-      cp.instance.device_names[0], &cp, cancel_mgr,
+      request->device_attributes(), &cp, cancel_mgr,
       [response, done](const Status& s, const GroupRec* gr) {
        if (s.ok()) {
          mutex_lock l(gr->mu);
          response->set_group_key(gr->group.group_key);
          response->set_group_size(gr->group.group_size);
          response->set_device_type(gr->group.device_type.type_string());
-          response->set_num_tasks(gr->task_set.size());
-          for (const string& dn : gr->device_list) {
-            response->add_device_name(dn);
-          }
-          for (const string& tn : gr->task_list) {
-            response->add_task_name(tn);
-          }
+          response->set_num_tasks(gr->group.num_tasks);
+          for (const auto& item : gr->devices) {
+            *response->add_device_attributes() = item.second;
+          }
          response->set_communicator_key(
              gr->group.runtime_details.communicator_key);
@@ -152,6 +156,22 @@ void CollectiveParamResolverDistributed::CompleteGroupAsync(
 void CollectiveParamResolverDistributed::CompleteInstanceAsync(
     const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
     CancellationManager* cancel_mgr, const StatusCallback& done) {
+  GroupRec* gr = GetCachedGroup(request->group_key());
+  if (gr == nullptr) {
+    done(errors::FailedPrecondition(
+        "group ", request->group_key(),
+        " not found. This normally means the server has restarted"));
+    return;
+  }
+  {
+    mutex_lock l(gr->mu);
+    if (!gr->status.ok() || gr->devices.size() != gr->group.group_size) {
+      done(errors::FailedPrecondition(
+          "group ", request->group_key(),
+          " failed to resolve. This normally means the server has restarted"));
+      return;
+    }
+  }
   CollectiveParams* cp = new CollectiveParams;
   cp->name = request->name();
   cp->group.group_key = request->group_key();
@@ -164,56 +184,44 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync(
   for (int32 offset : request->subdiv_offset()) {
     cp->instance.impl_details.subdiv_offsets.push_back(offset);
   }
-  string* device = new string(request->device());
-  VLOG(1) << "New cp " << cp << " for device " << *device << " : "
-          << cp->ToString();
-  StatusCallback done_and_cleanup = [cp, device, done](const Status& s) {
+  StatusCallback done_and_cleanup = [cp, done](const Status& s) {
     done(s);
     delete cp;
-    delete device;
   };
-  // Start by completing the group.
-  CompleteGroupDistributed(
-      *device, cp, cancel_mgr,
-      [this, cp, device, response, cancel_mgr, done_and_cleanup](
-          const Status& cg_status, const GroupRec* gr) {
-        if (cg_status.ok()) {
-          // Then complete the instance.
-          CompleteInstanceDistributed(
-              *device, gr, cp, cancel_mgr,
-              [this, gr, cp, response,
-               done_and_cleanup](const Status& ci_status) {
-                if (ci_status.ok()) {
-                  // Now source_rank should be known, so
-                  // retrieve it.
-                  FindInstanceRec(
-                      gr, cp,
-                      [cp, response, done_and_cleanup](const Status& fi_status,
-                                                       InstanceRec* ir) {
-                        if (fi_status.ok()) {
-                          mutex_lock l(ir->out_mu);
-                          ir->WaitForOutMu(l);
-                          response->set_instance_key(cp->instance.instance_key);
-                          response->set_source_rank(ir->source_rank);
-                          done_and_cleanup(fi_status);
-                        } else {
-                          done_and_cleanup(fi_status);
-                        }
-                      });
-                } else {
-                  done_and_cleanup(ci_status);
-                }
-              });
-        } else {
-          done_and_cleanup(cg_status);
-        }
-      });
+  CompleteInstanceDistributed(
+      request->device(), gr, cp, cancel_mgr,
+      [this, gr, cp, response, done_and_cleanup](const Status& ci_status) {
+        if (ci_status.ok()) {
+          // Now source_rank should be known, so
+          // retrieve it.
+          FindInstanceRec(
+              gr, cp,
+              [cp, response, done_and_cleanup](const Status& fi_status,
+                                               InstanceRec* ir) {
+                if (fi_status.ok()) {
+                  mutex_lock l(ir->out_mu);
+                  ir->WaitForOutMu(l);
+                  response->set_instance_key(cp->instance.instance_key);
+                  response->set_source_rank(ir->source_rank);
+                  done_and_cleanup(fi_status);
+                } else {
+                  done_and_cleanup(fi_status);
+                }
+              });
        } else {
          done_and_cleanup(ci_status);
        }
      });
 }
 
-bool CollectiveParamResolverDistributed::GroupIsCached(int32 group_key) {
+CollectiveParamResolverDistributed::GroupRec*
+CollectiveParamResolverDistributed::GetCachedGroup(int32 group_key) {
   mutex_lock l(group_mu_);
-  const auto& it = group_table_.find(group_key);
-  return it != group_table_.end();
+  auto it = group_table_.find(group_key);
+  if (it == group_table_.end()) {
+    return nullptr;
+  }
+  return it->second.get();
 }
 
 Status CollectiveParamResolverDistributed::UpdateGroupCache(
@@ -226,26 +234,19 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache(
     gr->group.group_key = resp.group_key();
     gr->group.group_size = resp.group_size();
     gr->group.num_tasks = resp.num_tasks();
-    if (resp.device_name_size() != gr->group.group_size) {
+    if (resp.device_attributes().empty()) {
+      return errors::Internal(
+          "CompleteGroupResponse device_attributes is empty. Make sure you're "
+          "running the same version of Tensorflow on all workers.");
+    }
+    if (resp.device_attributes_size() != gr->group.group_size) {
       return errors::Internal(
          "CompleteGroupResponse group_size doesn't match device_name list");
     }
-    for (const string& dn : resp.device_name()) {
-      gr->device_set.insert(dn);
-      gr->device_list.push_back(dn);
-    }
-    if (resp.task_name_size() != gr->group.group_size) {
-      return errors::Internal(
-          "CompleteGroupResponse group_size doesn't match task_name list");
-    }
-    for (const string& tn : resp.task_name()) {
-      gr->task_list.push_back(tn);
-      gr->task_set.insert(tn);
-    }
-    CHECK_EQ(gr->task_set.size(), gr->group.num_tasks);
+    for (const DeviceAttributes& device : resp.device_attributes()) {
+      gr->devices[device.name()] = device;
+    }
     gr->group.runtime_details.communicator_key = resp.communicator_key();
+    VLOG(2) << "Group communicator_key="
+            << absl::CEscape(gr->group.runtime_details.communicator_key);
   }
   {
     // Group membership should never change. Once a record is in group_table_
@@ -273,14 +274,15 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache(
 }
 
 void CollectiveParamResolverDistributed::CompleteGroupDistributed(
-    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
-    const GroupRecCallback& done) {
+    const DeviceAttributes& device, CollectiveParams* cp,
+    CancellationManager* cancel_mgr, const GroupRecCallback& done) {
   VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
-          << " dev: " << device << " is_leader=" << (group_leader_.empty());
+          << " dev: " << device.name()
+          << " is_leader=" << (group_leader_.empty());
   if (group_leader_.empty()) {
     // This is the group leader, so resolution is local.
     return CompleteGroupLocal(device, cp, done);
-  } else if (!GroupIsCached(cp->group.group_key)) {
+  } else if (GetCachedGroup(cp->group.group_key) == nullptr) {
    // Need to update Group cache from the leader.
    CompleteGroupCall* call =
        new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,

tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h

@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
 
 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
 
 namespace tensorflow {
 
 class ConfigProto;
@@ -31,7 +32,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
                                      WorkerCacheInterface* worker_cache,
                                      const string& task_name);
 
-  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
                            CancellationManager* cancel_mgr,
                            const StatusCallback& done) override;
@@ -46,9 +47,9 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
                             const StatusCallback& done) override;
 
  protected:
-  // Returns true iff there's an entry for this group_key in the
-  // local group_table_.
-  bool GroupIsCached(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
+  // Returns the cached group iff there's an entry for this group_key in the
+  // local group_table_; returns nullptr otherwise.
+  GroupRec* GetCachedGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
 
   // Updates group_table_ with contents of resp.
   Status UpdateGroupCache(const CompleteGroupResponse& resp)
@@ -59,7 +60,8 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
   //
   // Semantics are like those of CompleteGroupLocal but will make a
   // remote call to the group leader if necessary.
-  void CompleteGroupDistributed(const string& device, CollectiveParams* cp,
+  void CompleteGroupDistributed(const DeviceAttributes& device,
+                                CollectiveParams* cp,
                                 CancellationManager* cancel_mgr,
                                 const GroupRecCallback& done);

tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
 
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 #include "tensorflow/core/distributed_runtime/test_utils.h"
@@ -23,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/random.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -41,6 +43,7 @@ static std::unique_ptr<Device> NewDevice(const string& type,
   attr.set_name(name);
   attr.set_device_type(type);
   attr.mutable_locality()->set_numa_node(3);  // a non-default value
+  attr.set_incarnation(random::New64());
   return absl::make_unique<FakeDevice>(attr);
 }
@@ -125,127 +128,110 @@ class FakeCache : public TestWorkerCache {
 class DeviceResDistTest : public ::testing::Test {
  protected:
-  DeviceResDistTest() {}
-
-  ~DeviceResDistTest() override {
-    for (DeviceMgr* dm : device_mgrs_) {
-      delete dm;
-    }
-    for (auto it : dev_resolvers_) {
-      delete it.second;
-    }
-    for (auto it : cp_resolvers_) {
-      delete it.second;
-    }
-    for (FakeWorker* w : workers_) {
-      delete w;
-    }
-  }
-
   void DefineWorkers(int num_workers, int num_devices,
                      const string& device_type, bool nccl) {
-    ConfigProto config;
     for (int w = 0; w < num_workers; ++w) {
       string name = strings::StrCat("/job:worker/replica:0/task:", w);
-      if (w == 0) {
-        config.mutable_experimental()->set_collective_group_leader(name);
-        if (nccl) {
-          config.mutable_experimental()->set_collective_nccl(true);
-        }
-      }
-      DefineWorker(config, name, device_type, num_devices);
+      DefineWorker(name, device_type, num_devices, nccl);
     }
   }
 
-  void DefineWorker(const ConfigProto& config, const string& worker_name,
-                    const string& device_type, int num_devices) {
+  void DefineWorker(const string& worker_name, const string& device_type,
+                    int num_devices, bool nccl) {
+    ConfigProto config;
+    config.mutable_experimental()->set_collective_group_leader(
+        "/job:worker/replica:0/task:0");
+    config.mutable_experimental()->set_collective_nccl(nccl);
    std::vector<std::unique_ptr<Device>> devices;
    for (int i = 0; i < num_devices; ++i) {
      devices.push_back(NewDevice(
          device_type,
          strings::StrCat(worker_name, "/device:", device_type, ":", i)));
    }
-    DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
-    device_mgrs_.push_back(dev_mgr);
+    device_mgrs_[worker_name] =
+        absl::make_unique<StaticDeviceMgr>(std::move(devices));
    std::vector<string>* dv = &dev_by_task_[worker_name];
-    for (auto* d : dev_mgr->ListDevices()) {
+    dv->clear();
+    for (auto* d : device_mgrs_[worker_name]->ListDevices()) {
      dv->push_back(d->name());
    }
-    DeviceResolverDistributed* dev_res =
-        new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
-    dev_resolvers_[worker_name] = dev_res;
-    CollectiveParamResolverDistributed* cp_res =
-        new CollectiveParamResolverDistributed(config, dev_mgr, dev_res, &wc_,
-                                               worker_name);
-    cp_resolvers_[worker_name] = cp_res;
-    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, cp_res);
-    workers_.push_back(fw);
-    wc_.AddWorker(worker_name, fw);
+    dev_resolvers_[worker_name] = absl::make_unique<DeviceResolverDistributed>(
+        device_mgrs_[worker_name].get(), &wc_, worker_name);
+    cp_resolvers_[worker_name] =
+        absl::make_unique<CollectiveParamResolverDistributed>(
+            config, device_mgrs_[worker_name].get(),
+            dev_resolvers_[worker_name].get(), &wc_, worker_name);
+    workers_[worker_name] = absl::make_unique<FakeWorker>(
+        worker_name, device_mgrs_[worker_name].get(),
+        cp_resolvers_[worker_name].get());
+    wc_.AddWorker(worker_name, workers_[worker_name].get());
   }
 
-  void DefineCollectiveParams(int num_workers, int num_devices) {
-    const int kGroupKey = 5;
-    const int kInstanceKey = 3;
+  void DefineCollectiveParams(int num_workers, int num_devices,
+                              const string& device_type) {
    for (int wi = 0; wi < num_workers; ++wi) {
      string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
      for (int di = 0; di < num_devices; ++di) {
-        string device_name = strings::StrCat(task_name, "/device:CPU:", di);
-        cp_.push_back(CollectiveParams());
-        CollectiveParams& cp = cp_.back();
-        cp.group.group_key = kGroupKey;
-        cp.group.group_size = num_workers * num_devices;
-        cp.group.device_type = DEVICE_CPU;
-        cp.group.num_tasks = num_workers;
-        cp.instance.instance_key = kInstanceKey;
-        cp.instance.type = REDUCTION_COLLECTIVE;
-        cp.instance.data_type = DT_FLOAT;
-        cp.instance.shape = TensorShape({64});
-        cp.instance.impl_details.subdiv_offsets.push_back(0);
+        string device_name =
+            strings::StrCat(task_name, "/device:", device_type, ":", di);
+        cp_[device_name] =
+            CreateCollectiveParams(num_workers, num_devices, device_type);
      }
    }
  }
 
+  CollectiveParams CreateCollectiveParams(int num_workers, int num_devices,
+                                          const string& device_type) {
+    const int kGroupKey = 5;
+    const int kInstanceKey = 3;
+    CollectiveParams cp;
+    cp.group.group_key = kGroupKey;
+    cp.group.group_size = num_workers * num_devices;
+    cp.group.device_type = DeviceType(device_type);
+    cp.group.num_tasks = num_workers;
+    cp.instance.instance_key = kInstanceKey;
+    cp.instance.type = REDUCTION_COLLECTIVE;
+    cp.instance.data_type = DT_FLOAT;
+    cp.instance.shape = TensorShape({64});
+    cp.instance.impl_details.subdiv_offsets.push_back(0);
+    return cp;
+  }
+
   void IssueRequests(int num_workers, int num_devices) {
-    const int device_count = num_workers * num_devices;
    {
      mutex_lock l(mu_);
      num_done_ = 0;
    }
-    cp_.resize(device_count);
-    status_.resize(device_count);
-    int idx = 0;
+    int group_size = num_workers * num_devices;
    for (int wi = 0; wi < num_workers; ++wi) {
+      string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
      for (int di = 0; di < num_devices; ++di) {
-        IssueRequest(num_workers, num_devices, idx);
-        ++idx;
+        string device_name = strings::StrCat(task_name, "/device:CPU:", di);
+        IssueRequest(task_name, device_name, group_size);
      }
    }
  }
 
-  void IssueRequest(int num_workers, int num_devices, int idx) {
-    int device_count = num_workers * num_devices;
-    int wi = idx / num_devices;
-    int di = idx % num_devices;
-    string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
-    string device_name = strings::StrCat(task_name, "/device:CPU:", di);
-    while (idx >= cp_.size()) {
-      status_.resize(idx + 1);
-      cp_.resize(idx + 1);
-    }
-    CollectiveParams* cp = &cp_[idx];
-    CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name];
+  void IssueRequest(const string& task_name, const string& device_name,
+                    int group_size) {
+    Device* device = nullptr;
+    TF_CHECK_OK(device_mgrs_[task_name]->LookupDevice(device_name, &device));
+    CollectiveParams* cp = &cp_[device_name];
+    CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name].get();
    CHECK(cp_res);
-    cp_res->CompleteParamsAsync(device_name, cp, &cm_,
-                                [this, idx, device_count](const Status& s) {
-                                  status_[idx] = s;
-                                  {
-                                    mutex_lock l(mu_);
-                                    ++num_done_;
-                                    if (num_done_ == device_count) {
-                                      done_.notify_all();
-                                    }
-                                  }
-                                });
+    cp_res->CompleteParamsAsync(
+        device->attributes(), cp, &cm_,
+        [this, device_name, group_size](const Status& s) {
+          status_[device_name] = s;
+          {
+            mutex_lock l(mu_);
+            ++num_done_;
+            if (num_done_ == group_size) {
+              done_.notify_all();
+            }
+          }
+        });
  }
 
  void ValidateCollectiveParams(int num_workers, int num_devices) {
@@ -259,39 +245,59 @@ class DeviceResDistTest : public ::testing::Test {
    // Verify that all cp_ values get the same set of task and device
    // names, with unique default_rank in the expected order.
    const int dev_count = num_workers * num_devices;
+    string dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
    for (int wi = 0; wi < num_workers; ++wi) {
      string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
      for (int di = 0; di < num_devices; ++di) {
        string device_name = strings::StrCat(task_name, "/device:CPU:", di);
        int idx = wi * num_devices + di;
-        TF_ASSERT_OK(status_[idx]);
-        EXPECT_EQ(cp_[idx].default_rank, idx);
-        EXPECT_EQ(cp_[idx].instance.device_names.size(), dev_count);
-        EXPECT_EQ(cp_[idx].instance.device_names[idx], device_name);
-        EXPECT_EQ(cp_[idx].instance.task_names[idx], task_name);
+        TF_ASSERT_OK(status_[device_name]);
+        EXPECT_EQ(cp_[device_name].default_rank, idx);
+        EXPECT_EQ(cp_[device_name].instance.device_names.size(), dev_count);
+        EXPECT_EQ(cp_[device_name].instance.device_names[idx], device_name);
+        EXPECT_EQ(cp_[device_name].instance.task_names[idx], task_name);
        if (idx > 0) {
-          EXPECT_EQ(cp_[0].group.runtime_details.communicator_key,
-                    cp_[idx].group.runtime_details.communicator_key);
+          EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key,
+                    cp_[device_name].group.runtime_details.communicator_key);
          for (int i = 0; i < dev_count; ++i) {
-            EXPECT_EQ(cp_[0].instance.device_names[i],
-                      cp_[idx].instance.device_names[i]);
-            EXPECT_EQ(cp_[0].instance.task_names[i],
-                      cp_[idx].instance.task_names[i]);
+            EXPECT_EQ(cp_[dev0].instance.device_names[i],
+                      cp_[device_name].instance.device_names[i]);
+            EXPECT_EQ(cp_[dev0].instance.task_names[i],
+                      cp_[device_name].instance.task_names[i]);
          }
        }
      }
    }
  }
 
+  void RestartWorker(int worker_idx, int num_workers, int num_devices,
+                     const string& device_type, bool nccl) {
+    string worker_name =
+        strings::StrCat("/job:worker/replica:0/task:", worker_idx);
+    DefineWorker(worker_name, device_type, num_devices, nccl);
+    for (int i = 0; i < num_devices; ++i) {
+      string device_name =
+          strings::StrCat(worker_name, "/device:", device_type, ":", i);
+      cp_[device_name] =
+          CreateCollectiveParams(num_workers, num_devices, device_type);
+      status_.erase(device_name);
+    }
+  }
+
  FakeCache wc_;
  CancellationManager cm_;
-  std::vector<DeviceMgr*> device_mgrs_;
-  std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
-  std::unordered_map<string, CollectiveParamResolverDistributed*> cp_resolvers_;
-  std::unordered_map<string, std::vector<string>> dev_by_task_;
-  std::vector<FakeWorker*> workers_;
-  std::vector<CollectiveParams> cp_;
-  std::vector<Status> status_;
+  // Below are keyed by task names.
+  absl::flat_hash_map<string, std::unique_ptr<DeviceMgr>> device_mgrs_;
+  absl::flat_hash_map<string, std::unique_ptr<DeviceResolverDistributed>>
+      dev_resolvers_;
+  absl::flat_hash_map<string,
+                      std::unique_ptr<CollectiveParamResolverDistributed>>
+      cp_resolvers_;
+  absl::flat_hash_map<string, std::vector<string>> dev_by_task_;
+  absl::flat_hash_map<string, std::unique_ptr<FakeWorker>> workers_;
+  // Below are keyed by device names;
+  absl::flat_hash_map<string, CollectiveParams> cp_;
+  absl::flat_hash_map<string, Status> status_;
  mutex mu_;
  int num_done_ TF_GUARDED_BY(mu_);
  condition_variable done_;
@@ -300,8 +306,8 @@ class DeviceResDistTest : public ::testing::Test {
 TEST_F(DeviceResDistTest, Workers1Devices1) {
   const int num_workers = 1;
   const int num_devices = 1;
-  DefineWorkers(num_workers, num_devices, "CPU", false);
-  DefineCollectiveParams(num_workers, num_devices);
+  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
+  DefineCollectiveParams(num_workers, num_devices, "CPU");
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
@@ -309,12 +315,25 @@ TEST_F(DeviceResDistTest, Workers1Devices1) {
 TEST_F(DeviceResDistTest, Workers2Devices2) {
   const int num_workers = 2;
   const int num_devices = 2;
-  DefineWorkers(num_workers, num_devices, "CPU", false);
-  DefineCollectiveParams(num_workers, num_devices);
+  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
+  DefineCollectiveParams(num_workers, num_devices, "CPU");
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }
 
+TEST_F(DeviceResDistTest, DifferentIncarnation) {
+  const int num_workers = 2;
+  const int num_devices = 1;
+  DefineWorkers(num_workers, num_devices, "CPU", /*nccl*/ false);
+  DefineCollectiveParams(num_workers, num_devices, "CPU");
+  IssueRequests(num_workers, num_devices);
+  RestartWorker(1, num_workers, num_devices, "CPU", /*nccl*/ false);
+  const string task_name = "/job:worker/replica:0/task:1";
+  const string device_name = absl::StrCat(task_name, "/device:CPU:0");
+  IssueRequest(task_name, device_name, num_workers * num_devices);
+  EXPECT_TRUE(errors::IsFailedPrecondition(status_[device_name]));
+}
+
 #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
 namespace {
 // A mock NcclReducer for testing group runtime details initialization with CPU
@@ -347,7 +366,7 @@ TEST_F(DeviceResDistTest, Workers4Devices3) {
   const int num_workers = 4;
   const int num_devices = 3;
   DefineWorkers(num_workers, num_devices, "CPU", true);
-  DefineCollectiveParams(num_workers, num_devices);
+  DefineCollectiveParams(num_workers, num_devices, "CPU");
   IssueRequests(num_workers, num_devices);
   ValidateCollectiveParams(num_workers, num_devices);
 }

tensorflow/core/framework/collective.h

@@ -180,7 +180,8 @@ class ParamResolverInterface {
   // Called by each collective op at first execution in order to fill out
   // the CollectiveParams structure with data gathered from the full
   // (maybe distributed) collection of peer nodes.
-  virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+  virtual void CompleteParamsAsync(const DeviceAttributes& device,
+                                   CollectiveParams* cp,
                                    CancellationManager* cancel_mgr,
                                    const StatusCallback& done) = 0;
@@ -301,7 +302,8 @@ class CollectiveExecutor : public core::RefCounted {
         "a CollectiveExecutor has not been provided."));
   }
 
-  virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+  virtual void CompleteParamsAsync(const DeviceAttributes& device,
+                                   CollectiveParams* cp,
                                    CancellationManager* cancel_mgr,
                                    StatusCallback done) {
     done(errors::Internal(

tensorflow/core/kernels/collective_ops.cc

@@ -73,7 +73,7 @@ class CollectiveOpKernel : public AsyncOpKernel {
             << " group " << col_params_.group.group_key << " instance "
             << col_params_.instance.instance_key;
     col_exec->CompleteParamsAsync(
-        c->device()->name(), &col_params_, c->cancellation_manager(),
+        c->device()->attributes(), &col_params_, c->cancellation_manager(),
         [this, c, done](const Status& s) {
           if (s.ok()) {
             col_params_.instance.impl_details.dependencies = dependencies_;
@@ -538,7 +538,8 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
            << " group " << col_params->group.group_key << " instance "
            << col_params->instance.instance_key;
     col_exec->CompleteParamsAsync(
-        c->device()->name(), col_params.get(), c->cancellation_manager(),
+        c->device()->attributes(), col_params.get(),
+        c->cancellation_manager(),
         [c, done = std::move(done), col_params, col_exec](const Status& s) {
           if (s.ok()) {
             auto actual_done = [c, group_key = col_params->group.group_key,

tensorflow/core/protobuf/worker.proto

@@ -545,8 +545,10 @@ message CompleteGroupRequest {
   int32 group_key = 1;
   int32 group_size = 2;
   string device_type = 3;
-  repeated string device_name = 4;
   int32 collective_type = 5;
+  DeviceAttributes device_attributes = 6;
+
+  reserved 4;
 }
 
 // Gives the complete membership of the group identified by group_key.
@@ -555,9 +557,10 @@ message CompleteGroupResponse {
   int32 group_size = 2;
   string device_type = 3;
   int32 num_tasks = 4;  // number of distinct tasks hosting the devices
-  repeated string device_name = 5;
-  repeated string task_name = 6;  // task name prefixes of device_names
   bytes communicator_key = 7;
+  repeated DeviceAttributes device_attributes = 8;
+
+  reserved 5, 6;
 }
 
 // Supplies data about one collective op belonging to the instance identified