Ensure that CollectiveParams outlives all references to it.

Before this change, it was possible to access a `const CollectiveParams&`
after the underlying object had been destroyed. For example, the call to
`UnblockDependencies` in `NcclCommunicator::Enqueue` raced with the
done_callback of the collective participant.
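
To make the failure mode concrete, here is a minimal, self-contained sketch of
this class of race (illustrative only; `Params`, `EnqueueAsync`, and `Caller`
are hypothetical names, not code from this change):

#include <functional>
#include <thread>

struct Params { int group_size = 0; };

// The caller's stack frame owns `p`, but the detached thread may still be
// reading it after that frame unwinds -- a use-after-free.
void EnqueueAsync(const Params& p, std::function<void()> done) {
  std::thread([&p, done] {
    int n = p.group_size;  // May run after `p` is destroyed: undefined behavior.
    (void)n;
    done();
  }).detach();
}

void Caller() {
  Params params;                // Destroyed when Caller() returns...
  EnqueueAsync(params, [] {});  // ...possibly before the worker thread reads it.
}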

This change makes `CollectiveParams` a refcounted object and holds a
reference to it everywhere it may be accessed.
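
The mechanism used below is intrusive reference counting in the style of
`tensorflow::core::RefCounted` with `core::ScopedUnref`. As a rough sketch of
the contract (a simplified stand-in, not the actual TensorFlow implementation):

#include <atomic>

class RefCounted {
 public:
  RefCounted() : refs_(1) {}  // The creator holds the initial reference.
  void Ref() const { refs_.fetch_add(1, std::memory_order_relaxed); }
  void Unref() const {
    // acq_rel ordering makes prior writes visible to the deleting thread.
    if (refs_.fetch_sub(1, std::memory_order_acq_rel) == 1) delete this;
  }

 protected:
  virtual ~RefCounted() = default;  // Deleted only via the last Unref().

 private:
  mutable std::atomic<int> refs_;
};

// RAII helper mirroring core::ScopedUnref: drops one reference on scope exit.
class ScopedUnref {
 public:
  explicit ScopedUnref(const RefCounted* obj) : obj_(obj) {}
  ~ScopedUnref() {
    if (obj_) obj_->Unref();
  }

 private:
  const RefCounted* obj_;
};

Every consumer that may outlive the caller takes its own `Ref()` before the
handoff and drops it with `Unref()` (or `ScopedUnref`) when finished, so the
params are deleted only after the last user is done; the updated tests below
follow this pattern.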

PiperOrigin-RevId: 355646163
Change-Id: I7fd164afe8c1c9aa1c3b77a988930a0624977c7c
Ayush Dubey 2021-02-04 09:40:21 -08:00 committed by TensorFlower Gardener
parent e743fcee33
commit 18eaf4e8f1
20 changed files with 627 additions and 544 deletions

View File

@@ -264,7 +264,7 @@ Status BaseCollectiveExecutor::GetStatus(const Status& s) {
}
void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
const CollectiveParams& col_params,
const CollectiveParams* col_params,
const string& exec_key,
StatusCallback done) {
// See CompleteParamsAsync() how done() and the timeout callback interacts.
@@ -281,7 +281,7 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
}
};
auto timeout_microseconds = static_cast<int64>(
col_params.instance.impl_details.timeout_seconds * 1'000'000);
col_params->instance.impl_details.timeout_seconds * 1'000'000);
if (timeout_microseconds > 0) {
// TODO(xldrx): Share the timeout watchdog thread among collectives.
SchedNonBlockingClosureAfter(
@@ -297,15 +297,15 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
}
Tensor* output = ctx->mutable_output(0);
const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
col_params.instance.type == GATHER_COLLECTIVE ||
col_params.instance.type == PERMUTE_COLLECTIVE ||
(col_params.instance.type == BROADCAST_COLLECTIVE &&
col_params.is_source))
const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE ||
col_params->instance.type == GATHER_COLLECTIVE ||
col_params->instance.type == PERMUTE_COLLECTIVE ||
(col_params->instance.type == BROADCAST_COLLECTIVE &&
col_params->is_source))
? &ctx->input(0)
: nullptr;
CollectiveImplementationInterface* col_impl = nullptr;
Status status = CreateCollective(col_params, &col_impl);
Status status = CreateCollective(*col_params, &col_impl);
if (!status.ok()) {
done_safe(status);
DCHECK_EQ(nullptr, col_impl);

View File

@@ -110,7 +110,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
void StartAbort(const Status& s) override TF_LOCKS_EXCLUDED(status_mu_);
void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams* col_params,
const string& exec_key, StatusCallback done) override;
void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,

View File

@@ -513,14 +513,14 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
void CollectiveParamResolverLocal::InitInstanceSharedParams(
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
ir->shared.instance = cp->instance;
ir->shared.default_rank = -1;
ir->shared->instance = cp->instance;
ir->shared->default_rank = -1;
// Set is_local and task_names in *shared prior to invoking
// GetDeviceAttributesAsync. In a distributed context this function can be
// called by a derived class, some of the devices may be non-local and
// GetDeviceAttributesAsync will use those fields to launch RPCs.
CompleteTaskIsLocal(task_name_, &ir->shared);
CompleteTaskIsLocal(task_name_, ir->shared);
}
// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
@@ -662,11 +662,11 @@ void CollectiveParamResolverLocal::CompleteInstanceLocal(
if (!created_irec) {
// Check that the preexisting IRec is consistent with the params passed into
// this invocation.
if (ir->shared.instance.type != cp->instance.type ||
ir->shared.instance.data_type != cp->instance.data_type) {
if (ir->shared->instance.type != cp->instance.type ||
ir->shared->instance.data_type != cp->instance.data_type) {
done(errors::Internal("Collective instance ", cp->instance.instance_key,
" expected type ", ir->shared.instance.type,
" and data_type ", ir->shared.instance.data_type,
" expected type ", ir->shared->instance.type,
" and data_type ", ir->shared->instance.data_type,
" but got type ", cp->instance.type,
" and data_type ", cp->instance.data_type));
return;
@@ -686,7 +686,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
status = ir->status;
if (status.ok()) {
// custom operator= does a deep copy.
cp->instance = ir->shared.instance;
cp->instance = ir->shared->instance;
}
}
if (!status.ok()) {

View File

@@ -98,7 +98,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
struct InstanceRec {
mutex mu;
// Values to be shared by all instances, constant after initialization.
CollectiveParams shared;
CollectiveParams* shared;
// If an error occurs during initialization this structure stays in the
// table with a non-OK status. Purging the table and restarting needs to be
// done at a higher level.
@@ -113,7 +113,9 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
std::vector<bool> known TF_GUARDED_BY(mu);
std::vector<IRConsumer> known_waiters TF_GUARDED_BY(mu);
InstanceRec() : source_rank(-1), known_count(0) {}
InstanceRec()
: shared(new CollectiveParams()), source_rank(-1), known_count(0) {}
~InstanceRec() { shared->Unref(); }
};
// Find the InstanceRec with the same instance_key as cp. If it doesn't

View File

@@ -161,11 +161,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
}
TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
CollectiveParams cps[NUM_DEVS];
CollectiveParams* cps[NUM_DEVS];
Status statuses[NUM_DEVS];
Notification note[NUM_DEVS];
for (int i = 0; i < NUM_DEVS; ++i) {
CollectiveParams* cp = &cps[i];
cps[i] = new CollectiveParams();
CollectiveParams* cp = cps[i];
cp->group.group_key = 1;
cp->group.group_size = 3;
cp->group.device_type = DeviceType("CPU");
@@ -192,17 +193,18 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
}
for (int i = 0; i < NUM_DEVS; ++i) {
TF_ASSERT_OK(statuses[i]);
ASSERT_EQ(cps[i].group.device_names.size(), 3);
ASSERT_EQ(cps[i]->group.device_names.size(), 3);
for (int j = 0; j < NUM_DEVS; ++j) {
EXPECT_EQ(
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
cps[i].group.device_names[j]);
EXPECT_TRUE(cps[i].task.is_local[j]);
cps[i]->group.device_names[j]);
EXPECT_TRUE(cps[i]->task.is_local[j]);
}
EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
EXPECT_FALSE(cps[i].is_source);
EXPECT_EQ(cps[i].default_rank, i);
EXPECT_TRUE(cps[i].group.same_num_devices_per_task);
EXPECT_EQ(cps[i]->instance.impl_details.subdiv_source_rank.size(), 0);
EXPECT_FALSE(cps[i]->is_source);
EXPECT_EQ(cps[i]->default_rank, i);
EXPECT_TRUE(cps[i]->group.same_num_devices_per_task);
cps[i]->Unref();
}
}
@@ -223,11 +225,12 @@ void InitializeCollectiveParamsForBroadcast(int instance_key, int device_idx,
TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
constexpr int kInstanceKey = 5;
CollectiveParams cps[NUM_DEVS];
CollectiveParams* cps[NUM_DEVS];
Status statuses[NUM_DEVS];
Notification note[NUM_DEVS];
for (int i = 0; i < NUM_DEVS; ++i) {
CollectiveParams* cp = &cps[i];
cps[i] = new CollectiveParams();
CollectiveParams* cp = cps[i];
InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
string device =
@@ -245,16 +248,17 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
}
for (int i = 0; i < NUM_DEVS; ++i) {
TF_ASSERT_OK(statuses[i]);
ASSERT_EQ(cps[i].group.device_names.size(), 3);
ASSERT_EQ(cps[i]->group.device_names.size(), 3);
for (int j = 0; j < NUM_DEVS; ++j) {
EXPECT_EQ(
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
cps[i].group.device_names[j]);
EXPECT_TRUE(cps[i].task.is_local[j]);
cps[i]->group.device_names[j]);
EXPECT_TRUE(cps[i]->task.is_local[j]);
}
EXPECT_EQ(cps[i].is_source, (i == 1));
EXPECT_EQ(cps[i].default_rank, i);
EXPECT_TRUE(cps[i].group.same_num_devices_per_task);
EXPECT_EQ(cps[i]->is_source, (i == 1));
EXPECT_EQ(cps[i]->default_rank, i);
EXPECT_TRUE(cps[i]->group.same_num_devices_per_task);
cps[i]->Unref();
}
}
@@ -263,11 +267,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
// get an internal error from param resolution.
TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
constexpr int kInstanceKey = 8;
CollectiveParams cps[NUM_DEVS];
CollectiveParams* cps[NUM_DEVS];
Status statuses[NUM_DEVS];
Notification note[NUM_DEVS];
for (int i = 0; i < NUM_DEVS; ++i) {
CollectiveParams* cp = &cps[i];
cps[i] = new CollectiveParams();
CollectiveParams* cp = cps[i];
InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
string device =
@@ -291,27 +296,28 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
" found no source for broadcast. This could mean that there"
" were group_size=",
NUM_DEVS, " BcastRecvs but no BcastSend."));
cps[i]->Unref();
}
}
CollectiveParams MakeCollectiveParams(int group_key, int instance_key,
CollectiveParams* MakeCollectiveParams(int group_key, int instance_key,
bool is_source) {
CollectiveParams cp;
cp.group.group_key = group_key;
cp.group.group_size = NUM_DEVS;
cp.group.device_type = DeviceType("CPU");
cp.group.num_tasks = 1;
cp.instance.instance_key = instance_key;
auto* cp = new CollectiveParams();
cp->group.group_key = group_key;
cp->group.group_size = NUM_DEVS;
cp->group.device_type = DeviceType("CPU");
cp->group.num_tasks = 1;
cp->instance.instance_key = instance_key;
// CompleteInstanceLocal only waits for the group for broadcasts.
// Testing with broadcasts yields better coverage.
cp.instance.type = BROADCAST_COLLECTIVE;
cp.is_source = is_source;
cp->instance.type = BROADCAST_COLLECTIVE;
cp->is_source = is_source;
return cp;
}
TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) {
CancellationManager cancel_mgr;
std::vector<CollectiveParams> cp(NUM_DEVS - 1);
std::vector<CollectiveParams*> cp(NUM_DEVS - 1);
BlockingCounter start(NUM_DEVS - 1);
BlockingCounter done(NUM_DEVS - 1);
for (int i = 0; i < NUM_DEVS - 1; ++i) {
@@ -320,11 +326,12 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) {
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100,
/*is_source*/ i == 0);
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
&cancel_mgr, [&done](const Status& s) {
prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr,
[&done, cp = cp[i]](const Status& s) {
EXPECT_EQ(s.code(), error::ABORTED);
EXPECT_EQ(s.error_message(), "__aborted__");
done.DecrementCount();
cp->Unref();
});
start.DecrementCount();
});
@@ -336,7 +343,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) {
TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
CancellationManager cancel_mgr;
std::vector<CollectiveParams> cp(NUM_DEVS);
std::vector<CollectiveParams*> cp(NUM_DEVS);
int group_key = 100;
int instance_key = 100;
// First do a normal CompleteParamsAsync to complete the group;
@@ -349,10 +356,12 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
cp[i] = MakeCollectiveParams(group_key, instance_key,
/*is_source*/ i == 0);
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
&cancel_mgr, [&done](const Status& s) {
prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i],
&cancel_mgr,
[&done, cp = cp[i]](const Status& s) {
EXPECT_EQ(s.code(), error::OK);
done.DecrementCount();
cp->Unref();
});
});
}
@@ -361,18 +370,18 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
BlockingCounter start(NUM_DEVS - 1);
BlockingCounter done(NUM_DEVS - 1);
for (int i = 0; i < NUM_DEVS - 1; ++i) {
Env::Default()->SchedClosure(
[this, group_key, instance_key, i, &cancel_mgr, &cp, &start, &done] {
Env::Default()->SchedClosure([this, group_key, instance_key, i, &cancel_mgr,
&cp, &start, &done] {
string device =
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
cp[i] = MakeCollectiveParams(group_key, instance_key + 1,
/*is_source*/ i == 0);
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
&cancel_mgr, [&done](const Status& s) {
prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr,
[&done, cp = cp[i]](const Status& s) {
EXPECT_EQ(s.code(), error::ABORTED);
EXPECT_EQ(s.error_message(),
"__aborted__");
EXPECT_EQ(s.error_message(), "__aborted__");
done.DecrementCount();
cp->Unref();
});
start.DecrementCount();
});
@@ -388,7 +397,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
int instance_key = 100;
// First do a normal CompleteParamsAsync to complete the group;
{
std::vector<CollectiveParams> cp(NUM_DEVS);
std::vector<CollectiveParams*> cp(NUM_DEVS);
BlockingCounter done(NUM_DEVS);
for (int i = 0; i < NUM_DEVS; ++i) {
Env::Default()->SchedClosure([this, group_key, instance_key, i,
@@ -397,10 +406,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
cp[i] = MakeCollectiveParams(group_key, instance_key,
/*is_source*/ i == 0);
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
&cancel_mgr, [&done](const Status& s) {
prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i],
&cancel_mgr,
[&done, cp = cp[i]](const Status& s) {
EXPECT_EQ(s.code(), error::OK);
done.DecrementCount();
cp->Unref();
});
});
}
@@ -411,9 +422,10 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
auto complete_params = [this, &cancel_mgr](int group_key, int instance_key) {
string device = "/job:localhost/replica:0/task:0/device:CPU:0";
Notification done;
auto cp = MakeCollectiveParams(group_key, instance_key,
auto* cp = MakeCollectiveParams(group_key, instance_key,
/*is_source*/ true);
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, &cancel_mgr,
core::ScopedUnref unref(cp);
prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, &cancel_mgr,
[&done](const Status& s) {
EXPECT_EQ(s.code(), error::ABORTED);
EXPECT_EQ(s.error_message(), "__aborted__");
@@ -449,16 +461,17 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) {
while (true) {
Status status;
Notification n;
auto cp =
auto* cp =
MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key,
/*is_source*/ i == 0);
prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp,
prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp,
&cancel_mgr,
[&status, &n](const Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
cp->Unref();
// The status should be either OK or the aborted status.
if (!status.ok()) {
EXPECT_EQ(status.code(), error::ABORTED);

View File

@@ -188,7 +188,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) {
CHECK(col_ctx->dev_mgr);
col_ctx_ = col_ctx;
col_params_ = &col_ctx->col_params;
col_params_ = col_ctx->col_params;
return collective_util::InitializeDeviceAndLocality(
col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
&col_ctx->device_locality);

View File

@@ -58,17 +58,18 @@ class TrivialTest : public ::testing::Test {
// ST = send_to rank vector
#define DEF_TL_TEST(D, S, R, RF, ST) \
TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \
CollectiveParams cp; \
cp.group.group_size = D; \
cp.instance.impl_details.subdiv_source_rank = {S}; \
cp.instance.impl_details.subdiv_permutations.push_back( \
auto* cp = new CollectiveParams(); \
core::ScopedUnref unref(cp); \
cp->group.group_size = D; \
cp->instance.impl_details.subdiv_source_rank = {S}; \
cp->instance.impl_details.subdiv_permutations.push_back( \
std::vector<int>(D, 0)); \
cp.subdiv_rank = {R}; \
cp.is_source = (S == R); \
EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(cp, 0)); \
cp->subdiv_rank = {R}; \
cp->is_source = (S == R); \
EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(*cp, 0)); \
std::vector<int> expected = ST; \
std::vector<int> send_to; \
HierarchicalTreeBroadcaster::TreeSendTo(cp, 0, &send_to); \
HierarchicalTreeBroadcaster::TreeSendTo(*cp, 0, &send_to); \
ASSERT_EQ(expected.size(), send_to.size()); \
for (int i = 0; i < expected.size(); ++i) { \
EXPECT_EQ(expected[i], send_to[i]); \
@@ -196,12 +197,14 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
class HierarchicalTreeBroadcasterTest : public ::testing::Test {
protected:
HierarchicalTreeBroadcasterTest() : device_type_(DEVICE_CPU) {}
HierarchicalTreeBroadcasterTest()
: device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}
~HierarchicalTreeBroadcasterTest() override {
stop_ = true;
for (auto i : instances_) delete i;
if (col_exec_) col_exec_->Unref();
if (col_params_) col_params_->Unref();
}
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -262,30 +265,31 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
dev_mgr_.get(),
gpu_ring_order_.get(), work_queue_);
col_params_.name = "test_collective";
col_params_.instance.data_type = dtype;
col_params_ = new CollectiveParams();
col_params_->name = "test_collective";
col_params_->instance.data_type = dtype;
static const int kGroupKey = 6;
col_params_.group.group_key = kGroupKey;
col_params_->group.group_key = kGroupKey;
static const int kInstanceKey = 18;
col_params_.instance.instance_key = kInstanceKey;
col_params_.group.device_type = device_type;
col_params_.group.group_size = num_workers * num_devices_per_worker;
col_params_.instance.impl_details.subdiv_offsets.clear();
col_params_.instance.type = BROADCAST_COLLECTIVE;
col_params_->instance.instance_key = kInstanceKey;
col_params_->group.device_type = device_type;
col_params_->group.group_size = num_workers * num_devices_per_worker;
col_params_->instance.impl_details.subdiv_offsets.clear();
col_params_->instance.type = BROADCAST_COLLECTIVE;
int num_subdivs = num_workers + (num_workers > 1 ? 1 : 0);
VLOG(2) << "#subdiv=" << num_subdivs;
col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs);
col_params_.subdiv_rank.resize(num_subdivs);
col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs);
col_params_->subdiv_rank.resize(num_subdivs);
// Inter-machine broadcast.
int subdiv_i = 0;
if (num_workers > 1) {
col_params_.instance.impl_details.subdiv_permutations[subdiv_i].resize(
col_params_->instance.impl_details.subdiv_permutations[subdiv_i].resize(
total_num_devices, -1);
for (int i = 0, rank = 0; i < total_num_devices; i++) {
if (i % num_devices_per_worker == 0) {
col_params_.instance.impl_details
col_params_->instance.impl_details
.subdiv_permutations[subdiv_i][rank] = i;
rank++;
}
@@ -293,7 +297,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
if (VLOG_IS_ON(2)) {
string sp_buf;
for (int p :
col_params_.instance.impl_details.subdiv_permutations[subdiv_i])
col_params_->instance.impl_details.subdiv_permutations[subdiv_i])
strings::StrAppend(&sp_buf, p, ", ");
VLOG(2) << "subdiv_i=" << subdiv_i << " perm=" << sp_buf;
}
@@ -301,22 +305,22 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
}
// Intra-machine broadcast.
for (int i = 0; subdiv_i < num_subdivs; i++, subdiv_i++) {
col_params_.instance.impl_details.subdiv_permutations[subdiv_i].resize(
col_params_->instance.impl_details.subdiv_permutations[subdiv_i].resize(
total_num_devices, -1);
int perm_i_base = i * num_devices_per_worker;
VLOG(2) << "subdiv_i=" << subdiv_i << " i=" << i
<< " perm_i_base=" << perm_i_base << " subdiv_perms.size="
<< col_params_.instance.impl_details.subdiv_permutations.size();
<< col_params_->instance.impl_details.subdiv_permutations.size();
// subdiv for worker i.
for (int j = perm_i_base, rank = 0;
j < perm_i_base + num_devices_per_worker; j++, rank++) {
col_params_.instance.impl_details.subdiv_permutations[subdiv_i][rank] =
col_params_->instance.impl_details.subdiv_permutations[subdiv_i][rank] =
j;
}
if (VLOG_IS_ON(2)) {
string sp_buf;
for (int p :
col_params_.instance.impl_details.subdiv_permutations[subdiv_i])
col_params_->instance.impl_details.subdiv_permutations[subdiv_i])
strings::StrAppend(&sp_buf, p, ", ");
VLOG(2) << "subdiv_i=" << subdiv_i << " perm=" << sp_buf;
}
@@ -333,16 +337,16 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
dev_name = strings::StrCat(task_name, "/device:CPU:", di);
}
VLOG(2) << "dev=" << dev_name;
col_params_.group.device_names.push_back(dev_name);
col_params_.group.task_names.push_back(task_name);
col_params_.task.is_local.push_back(true);
col_params_->group.device_names.push_back(dev_name);
col_params_->group.task_names.push_back(task_name);
col_params_->task.is_local.push_back(true);
}
}
for (int wi = 0; wi < num_workers; wi++) {
for (int di = 0; di < num_devices_per_worker; di++) {
int default_rank = wi * num_devices_per_worker + di;
instances_.push_back(new DeviceInstance(
default_rank, col_params_.group.device_names[default_rank],
default_rank, col_params_->group.device_names[default_rank],
device_type, this));
}
}
@@ -435,7 +439,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
// Copy the expected value from the broadcast source tensor
std::vector<T> expected(tensor_len, 0.0);
const CollectiveParams& cp = instances_[0]->col_params_;
const CollectiveParams& cp = *instances_[0]->col_params_;
int broadcast_dev_id =
cp.instance.impl_details.subdiv_permutations
[0][cp.instance.impl_details.subdiv_source_rank[0]];
@@ -558,27 +562,29 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
: parent_(parent),
dev_name_(dev_name),
device_type_(device_type),
rank_(rank) {
rank_(rank),
col_params_(new CollectiveParams()) {
TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_));
col_params_.name = parent_->col_params_.name;
col_params_.instance.data_type = parent_->col_params_.instance.data_type;
col_params_.group = parent_->col_params_.group;
col_params_.instance.instance_key =
parent_->col_params_.instance.instance_key;
col_params_.task.is_local = parent_->col_params_.task.is_local;
col_params_.instance.impl_details.subdiv_permutations =
parent_->col_params_.instance.impl_details.subdiv_permutations;
col_params_.subdiv_rank = parent_->col_params_.subdiv_rank;
col_params_->name = parent_->col_params_->name;
col_params_->instance.data_type =
parent_->col_params_->instance.data_type;
col_params_->group = parent_->col_params_->group;
col_params_->instance.instance_key =
parent_->col_params_->instance.instance_key;
col_params_->task.is_local = parent_->col_params_->task.is_local;
col_params_->instance.impl_details.subdiv_permutations =
parent_->col_params_->instance.impl_details.subdiv_permutations;
col_params_->subdiv_rank = parent_->col_params_->subdiv_rank;
int group_size = col_params_.group.group_size;
CHECK_EQ(group_size, col_params_.group.device_names.size());
int group_size = col_params_->group.group_size;
CHECK_EQ(group_size, col_params_->group.device_names.size());
// Default rank is order in device_names.
col_params_.default_rank = rank;
col_params_->default_rank = rank;
auto& impl = col_params_.instance.impl_details;
auto& impl = col_params_->instance.impl_details;
size_t num_subdivs = impl.subdiv_permutations.size();
impl.subdiv_source_rank.resize(num_subdivs, 0);
col_params_.subdiv_rank.resize(num_subdivs);
col_params_->subdiv_rank.resize(num_subdivs);
for (size_t si = 0; si < num_subdivs; si++) {
int perm_rank = -1;
for (int i = 0; i < group_size; i++) {
@@ -587,18 +593,20 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
break;
}
}
col_params_.subdiv_rank[si] = perm_rank;
col_params_->subdiv_rank[si] = perm_rank;
}
string rank_buf;
for (int r : col_params_.subdiv_rank) {
for (int r : col_params_->subdiv_rank) {
strings::StrAppend(&rank_buf, r, ", ");
}
VLOG(1) << "default=" << rank << " subdiv_ranks=" << rank_buf;
col_params_.is_source =
col_params_.subdiv_rank[0] == impl.subdiv_source_rank[0];
col_params_->is_source =
col_params_->subdiv_rank[0] == impl.subdiv_source_rank[0];
}
~DeviceInstance() { col_params_->Unref(); }
void InitTensor(DataType dtype, const TensorShape& shape,
const InitFunc& f) {
tensor_ =
@@ -641,22 +649,22 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
op_params.op_device_context = dev_ctx;
int forward_from[] = {OpKernelContext::Params::kNeverForward};
if (forward_input) forward_from[0] = 0;
if (col_params_.is_source) {
if (col_params_->is_source) {
op_params.forward_from_array = &forward_from[0];
}
AllocatorAttributes generic_alloc_attr;
op_params.output_attr_array = &generic_alloc_attr;
std::unique_ptr<OpKernel> op =
col_params_.is_source
? parent_->GetCollectiveBcastSend(col_params_, &tensor_,
col_params_->is_source
? parent_->GetCollectiveBcastSend(*col_params_, &tensor_,
DEVICE_CPU, device_)
: parent_->GetCollectiveBcastRecv(col_params_, tensor_.shape(),
: parent_->GetCollectiveBcastRecv(*col_params_, tensor_.shape(),
DEVICE_CPU, device_);
op_params.op_kernel = op.get();
OpKernelContext ctx(&op_params, 1);
Tensor* output_tensor_ptr = nullptr;
if (col_params_.is_source) {
if (col_params_->is_source) {
TF_CHECK_OK(ctx.forward_input_or_allocate_output(
{0}, 0, tensor_.shape(), &output_tensor_ptr));
} else {
@@ -665,11 +673,11 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
}
CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
const Tensor* input_tensor_ptr =
col_params_.is_source ? &tensor_ : nullptr;
col_params_->is_source ? &tensor_ : nullptr;
// Prepare a Broadcaster instance.
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
strings::StrCat(col_params_->instance.instance_key, ":0:0");
HierarchicalTreeBroadcaster* broadcaster =
new HierarchicalTreeBroadcaster;
core::ScopedUnref unref(broadcaster);
@@ -694,7 +702,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
int rank_;
Tensor tensor_;
Device* device_;
CollectiveParams col_params_;
CollectiveParams* col_params_;
Status status_;
}; // class DeviceInstance
@@ -708,7 +716,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
CollectiveParams* col_params_;
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
std::unique_ptr<string> gpu_ring_order_;
@@ -720,33 +728,35 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
};
TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams1Task8GPU) {
CollectiveParams cp;
PrepColParamsForSubdivPermsTest(&cp, 1, 8);
auto* cp = new CollectiveParams();
core::ScopedUnref unref(cp);
PrepColParamsForSubdivPermsTest(cp, 1, 8);
// source 0 device 0
cp.source_rank = 0;
cp.default_rank = 0;
RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0});
cp->source_rank = 0;
cp->default_rank = 0;
RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0});
// source 2 device 2
cp.source_rank = 2;
cp.default_rank = 2;
RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2});
cp->source_rank = 2;
cp->default_rank = 2;
RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2});
// source 2 device 0
cp.source_rank = 2;
cp.default_rank = 0;
RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2});
cp->source_rank = 2;
cp->default_rank = 0;
RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2});
}
TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) {
CollectiveParams cp;
PrepColParamsForSubdivPermsTest(&cp, 4, 8);
auto* cp = new CollectiveParams();
core::ScopedUnref unref(cp);
PrepColParamsForSubdivPermsTest(cp, 4, 8);
// source 0 device 0
cp.source_rank = 0;
cp.default_rank = 0;
RunSubdivPermsTest(&cp,
cp->source_rank = 0;
cp->default_rank = 0;
RunSubdivPermsTest(cp,
{{0, 8, 16, 24},
{0, 1, 2, 3, 4, 5, 6, 7},
{8, 9, 10, 11, 12, 13, 14, 15},
@@ -755,9 +765,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) {
{0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
// source 2 device 0
cp.source_rank = 2;
cp.default_rank = 0;
RunSubdivPermsTest(&cp,
cp->source_rank = 2;
cp->default_rank = 0;
RunSubdivPermsTest(cp,
{{2, 8, 16, 24},
{0, 1, 2, 3, 4, 5, 6, 7},
{8, 9, 10, 11, 12, 13, 14, 15},
@@ -766,9 +776,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) {
{-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
// source 9 device 9
cp.source_rank = 9;
cp.default_rank = 9;
RunSubdivPermsTest(&cp,
cp->source_rank = 9;
cp->default_rank = 9;
RunSubdivPermsTest(cp,
{{0, 9, 16, 24},
{0, 1, 2, 3, 4, 5, 6, 7},
{8, 9, 10, 11, 12, 13, 14, 15},
@@ -778,28 +788,29 @@
}
TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
CollectiveParams cp;
auto* cp = new CollectiveParams();
core::ScopedUnref unref(cp);
int num_tasks = 4;
cp.group.device_type = DeviceType("GPU");
cp.group.num_tasks = num_tasks;
cp.group.group_size = 0;
cp.instance.type = BROADCAST_COLLECTIVE;
cp.instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
cp->group.device_type = DeviceType("GPU");
cp->group.num_tasks = num_tasks;
cp->group.group_size = 0;
cp->instance.type = BROADCAST_COLLECTIVE;
cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
std::vector<int> dev_per_task = {4, 4, 6, 8};
for (int ti = 0; ti < cp.group.num_tasks; ti++) {
for (int ti = 0; ti < cp->group.num_tasks; ti++) {
string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
for (int di = 0; di < dev_per_task[ti]; di++) {
string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
cp.group.task_names.push_back(task_name);
cp.group.device_names.push_back(dev_name);
cp.group.group_size++;
cp->group.task_names.push_back(task_name);
cp->group.device_names.push_back(dev_name);
cp->group.group_size++;
}
}
// source 0 device 0
cp.source_rank = 0;
cp.default_rank = 0;
RunSubdivPermsTest(&cp,
cp->source_rank = 0;
cp->default_rank = 0;
RunSubdivPermsTest(cp,
{{0, 4, 8, 14},
{0, 1, 2, 3},
{4, 5, 6, 7},
@@ -808,9 +819,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
{0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
// source 2 device 0
cp.source_rank = 2;
cp.default_rank = 0;
RunSubdivPermsTest(&cp,
cp->source_rank = 2;
cp->default_rank = 0;
RunSubdivPermsTest(cp,
{{2, 4, 8, 14},
{0, 1, 2, 3},
{4, 5, 6, 7},
@@ -819,9 +830,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
{-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
// source 9 device 5
cp.source_rank = 9;
cp.default_rank = 5;
RunSubdivPermsTest(&cp,
cp->source_rank = 9;
cp->default_rank = 5;
RunSubdivPermsTest(cp,
{{0, 4, 9, 14},
{0, 1, 2, 3},
{4, 5, 6, 7},

View File

@@ -54,7 +54,7 @@ Status Permuter::InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) {
DCHECK(col_ctx->dev_mgr);
col_ctx_ = col_ctx;
col_params_ = &col_ctx->col_params;
col_params_ = col_ctx->col_params;
return collective_util::InitializeDeviceAndLocality(
col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
&col_ctx->device_locality);

View File

@@ -107,12 +107,14 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
class PermuterTest : public ::testing::Test {
protected:
PermuterTest() : device_type_(DEVICE_CPU) {}
PermuterTest()
: device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}
~PermuterTest() override {
stop_ = true;
for (auto i : instances_) delete i;
if (col_exec_) col_exec_->Unref();
if (col_params_) col_params_->Unref();
}
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -170,12 +172,13 @@ class PermuterTest : public ::testing::Test {
col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
dev_mgr_.get(),
gpu_ring_order_.get(), work_queue_);
col_params_.name = "test_collective";
col_params_.instance.data_type = dtype;
col_params_ = new CollectiveParams();
col_params_->name = "test_collective";
col_params_->instance.data_type = dtype;
static const int kInstanceKey = 18;
col_params_.instance.instance_key = kInstanceKey;
col_params_.group.device_type = device_type;
col_params_.instance.type = PERMUTE_COLLECTIVE;
col_params_->instance.instance_key = kInstanceKey;
col_params_->group.device_type = device_type;
col_params_->instance.type = PERMUTE_COLLECTIVE;
// Set up all the fake device contexts.
for (int wi = 0; wi < num_workers; wi++) {
@@ -187,12 +190,12 @@ class PermuterTest : public ::testing::Test {
} else {
dev_name = strings::StrCat(task_name, "/device:CPU:", di);
}
col_params_.group.device_names.push_back(dev_name);
col_params_.instance.devices.push_back(dev_name);
col_params_->group.device_names.push_back(dev_name);
col_params_->instance.devices.push_back(dev_name);
int default_rank = wi * num_devices_per_worker + di;
permutation_.push_back(default_rank);
col_params_.group.task_names.push_back(task_name);
col_params_.task.is_local.push_back(true);
col_params_->group.task_names.push_back(task_name);
col_params_->task.is_local.push_back(true);
}
}
@@ -210,13 +213,13 @@ class PermuterTest : public ::testing::Test {
std::next_permutation(permutation_.begin() + i,
permutation_.begin() + i + 2);
}
col_params_.instance.permutation = permutation_;
col_params_->instance.permutation = permutation_;
for (int wi = 0; wi < num_workers; wi++) {
for (int di = 0; di < num_devices_per_worker; di++) {
int default_rank = wi * num_devices_per_worker + di;
instances_.push_back(new DeviceInstance(
default_rank, col_params_.group.device_names[default_rank],
default_rank, col_params_->group.device_names[default_rank],
device_type, this));
}
}
@@ -320,25 +323,30 @@ class PermuterTest : public ::testing::Test {
: parent_(parent),
dev_name_(dev_name),
device_type_(device_type),
rank_(rank) {
rank_(rank),
col_params_(new CollectiveParams()) {
TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_));
col_params_.name = parent_->col_params_.name;
col_params_.instance.data_type = parent_->col_params_.instance.data_type;
col_params_.instance.instance_key =
parent_->col_params_.instance.instance_key;
col_params_.group.device_type = parent_->col_params_.group.device_type;
col_params_.group.device_names = parent_->col_params_.group.device_names;
col_params_.instance.devices = parent_->col_params_.instance.devices;
col_params_.instance.permutation =
parent->col_params_.instance.permutation;
col_params_.group.task_names = parent_->col_params_.group.task_names;
col_params_.task.is_local = parent_->col_params_.task.is_local;
CHECK_EQ(col_params_.instance.devices.size(),
col_params_.group.device_names.size());
col_params_->name = parent_->col_params_->name;
col_params_->instance.data_type =
parent_->col_params_->instance.data_type;
col_params_->instance.instance_key =
parent_->col_params_->instance.instance_key;
col_params_->group.device_type = parent_->col_params_->group.device_type;
col_params_->group.device_names =
parent_->col_params_->group.device_names;
col_params_->instance.devices = parent_->col_params_->instance.devices;
col_params_->instance.permutation =
parent->col_params_->instance.permutation;
col_params_->group.task_names = parent_->col_params_->group.task_names;
col_params_->task.is_local = parent_->col_params_->task.is_local;
CHECK_EQ(col_params_->instance.devices.size(),
col_params_->group.device_names.size());
// Default rank is order in device_names.
col_params_.default_rank = rank;
col_params_->default_rank = rank;
}
~DeviceInstance() { col_params_->Unref(); }
void InitTensor(DataType dtype, const TensorShape& shape,
const InitFunc& f) {
tensor_input_ =
@@ -387,7 +395,7 @@ class PermuterTest : public ::testing::Test {
// Prepare a Permuter instance.
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
strings::StrCat(col_params_->instance.instance_key, ":0:0");
Permuter* permuter = new Permuter;
core::ScopedUnref unref(permuter);
auto col_ctx = std::make_shared<CollectiveContext>(
@@ -412,7 +420,7 @@ class PermuterTest : public ::testing::Test {
Tensor tensor_input_;
Tensor tensor_output_;
Device* device_;
CollectiveParams col_params_;
CollectiveParams* col_params_;
Status status_;
}; // class DeviceInstance
@@ -425,7 +433,7 @@ class PermuterTest : public ::testing::Test {
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
CollectiveParams* col_params_;
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
std::unique_ptr<string> gpu_ring_order_;

View File

@@ -245,7 +245,7 @@ Status RingAlg::InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) {
DCHECK(col_ctx->dev_mgr);
col_ctx_ = col_ctx;
col_params_ = &col_ctx->col_params;
col_params_ = col_ctx->col_params;
return collective_util::InitializeDeviceAndLocality(
col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
&col_ctx->device_locality);

View File

@@ -115,7 +115,8 @@ static int64 kStepId = 123;
class RingGathererTest : public ::testing::Test {
protected:
RingGathererTest() : device_type_(DEVICE_CPU) {}
RingGathererTest()
: device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
void InitGPUDevices() {
@@ -132,6 +133,7 @@ class RingGathererTest : public ::testing::Test {
stop_ = true;
for (auto i : instances_) delete i;
if (col_exec_) col_exec_->Unref();
if (col_params_) col_params_->Unref();
}
void Init(int num_workers, int num_devices, DataType dtype,
@@ -180,24 +182,25 @@ class RingGathererTest : public ::testing::Test {
col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
dev_mgr_.get(),
gpu_ring_order_.get(), work_queue_);
col_params_.name = "test_collective";
col_params_ = new CollectiveParams();
col_params_->name = "test_collective";
static const int kGroupKey = 5;
col_params_.group.group_key = kGroupKey;
col_params_.group.device_type = device_type;
col_params_.group.group_size = num_workers * num_devices;
col_params_->group.group_key = kGroupKey;
col_params_->group.device_type = device_type;
col_params_->group.group_size = num_workers * num_devices;
static const int kInstanceKey = 17;
col_params_.instance.instance_key = kInstanceKey;
col_params_.instance.impl_details.subdiv_offsets.clear();
col_params_.instance.type = GATHER_COLLECTIVE;
col_params_.instance.impl_details.collective_name = "RingGather";
col_params_.instance.data_type = dtype;
col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs);
col_params_.subdiv_rank.resize(num_subdivs);
col_params_->instance.instance_key = kInstanceKey;
col_params_->instance.impl_details.subdiv_offsets.clear();
col_params_->instance.type = GATHER_COLLECTIVE;
col_params_->instance.impl_details.collective_name = "RingGather";
col_params_->instance.data_type = dtype;
col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs);
col_params_->subdiv_rank.resize(num_subdivs);
int subdiv_stride = num_devices / num_subdivs;
for (int sdi = 0; sdi < num_subdivs; ++sdi) {
col_params_.instance.impl_details.subdiv_offsets.push_back(sdi *
subdiv_stride);
col_params_.subdiv_rank[sdi] = sdi * subdiv_stride;
col_params_->instance.impl_details.subdiv_offsets.push_back(
sdi * subdiv_stride);
col_params_->subdiv_rank[sdi] = sdi * subdiv_stride;
}
// Set up a local device ring order that's not just 0,1,2...
@@ -225,16 +228,16 @@ class RingGathererTest : public ::testing::Test {
dev_name =
strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size());
}
col_params_.group.device_names.push_back(dev_name);
col_params_.group.task_names.push_back(task_name);
col_params_->group.device_names.push_back(dev_name);
col_params_->group.task_names.push_back(task_name);
// Normally each device would set is_local to its own perspective but
// this test runs in a single process so is_local is always true.
col_params_.task.is_local.push_back(true);
col_params_->task.is_local.push_back(true);
for (int sdi = 0; sdi < num_subdivs; ++sdi) {
int rotated_di =
(di + col_params_.instance.impl_details.subdiv_offsets[sdi]) %
(di + col_params_->instance.impl_details.subdiv_offsets[sdi]) %
num_devices;
col_params_.instance.impl_details.subdiv_permutations[sdi].push_back(
col_params_->instance.impl_details.subdiv_permutations[sdi].push_back(
wi * num_devices + local_ring_order[rotated_di]);
}
}
@@ -243,7 +246,7 @@ class RingGathererTest : public ::testing::Test {
for (int di = 0; di < num_devices; ++di) {
int rank = wi * num_devices + di;
instances_.push_back(new DeviceInstance(
rank, col_params_.group.device_names[rank], device_type_, this));
rank, col_params_->group.device_names[rank], device_type_, this));
}
}
}
@@ -387,39 +390,42 @@ class RingGathererTest : public ::testing::Test {
: parent_(parent),
dev_name_(dev_name),
device_type_(device_type),
rank_(rank) {
rank_(rank),
col_params_(new CollectiveParams()) {
TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_))
<< "Couldn't find device " << dev_name
<< " existing devices: " << parent_->dev_mgr_->DebugString();
col_params_.name = parent_->col_params_.name;
col_params_.group = parent_->col_params_.group;
col_params_.instance = parent->col_params_.instance;
col_params_.task.is_local = parent_->col_params_.task.is_local;
col_params_.subdiv_rank = parent_->col_params_.subdiv_rank;
col_params_->name = parent_->col_params_->name;
col_params_->group = parent_->col_params_->group;
col_params_->instance = parent->col_params_->instance;
col_params_->task.is_local = parent_->col_params_->task.is_local;
col_params_->subdiv_rank = parent_->col_params_->subdiv_rank;
int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
int group_size = col_params_.group.group_size;
int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
int group_size = col_params_->group.group_size;
CHECK_EQ(group_size,
static_cast<int>(col_params_.group.device_names.size()));
static_cast<int>(col_params_->group.device_names.size()));
// Id of this device is at rank position in first subdiv perm.
int my_device_id =
col_params_.instance.impl_details.subdiv_permutations[0][rank];
col_params_.default_rank = my_device_id;
col_params_->instance.impl_details.subdiv_permutations[0][rank];
col_params_->default_rank = my_device_id;
// Set rank for all other subdivs by finding that device_id.
for (int sdi = 0; sdi < num_subdivs; ++sdi) {
for (int r = 0; r < static_cast<int>(col_params_.instance.impl_details
for (int r = 0; r < static_cast<int>(col_params_->instance.impl_details
.subdiv_permutations[sdi]
.size());
++r) {
if (my_device_id ==
col_params_.instance.impl_details.subdiv_permutations[sdi][r]) {
col_params_.subdiv_rank[sdi] = r;
col_params_->instance.impl_details.subdiv_permutations[sdi][r]) {
col_params_->subdiv_rank[sdi] = r;
break;
}
}
}
}
~DeviceInstance() { col_params_->Unref(); }
void InitTensor(DataType dtype, const TensorShape& shape,
const std::function<void(Tensor*)>& init_f) {
input_tensor_ =
@@ -464,7 +470,7 @@ class RingGathererTest : public ::testing::Test {
AllocatorAttributes generic_alloc_attr;
op_params.output_attr_array = &generic_alloc_attr;
std::unique_ptr<OpKernel> op = parent_->GetCollectiveGather(
col_params_, &input_tensor_, DEVICE_CPU, device_);
*col_params_, &input_tensor_, DEVICE_CPU, device_);
op_params.op_kernel = op.get();
OpKernelContext ctx(&op_params, 1);
@@ -478,7 +484,7 @@ class RingGathererTest : public ::testing::Test {
CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
// Prepare a RingGatherer instance.
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
strings::StrCat(col_params_->instance.instance_key, ":0:0");
RingGatherer* gatherer = new RingGatherer;
core::ScopedUnref unref(gatherer);
auto col_ctx = std::make_shared<CollectiveContext>(
@@ -507,7 +513,7 @@ class RingGathererTest : public ::testing::Test {
Tensor input_tensor_;
Tensor output_tensor_;
Device* device_;
CollectiveParams col_params_;
CollectiveParams* col_params_;
std::unique_ptr<CollectiveAdapter> ca_;
std::unique_ptr<OpKernelContext> ctx_;
Status status_;
@@ -521,7 +527,7 @@ class RingGathererTest : public ::testing::Test {
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
CollectiveParams* col_params_;
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
std::unique_ptr<string> gpu_ring_order_;
@@ -530,28 +536,28 @@ class RingGathererTest : public ::testing::Test {
CancellationManager cancellation_manager_;
};
CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
CollectiveParams* SetUpCollectiveParams(const int num_devs_per_task,
const int num_tasks) {
CollectiveParams cp;
auto* cp = new CollectiveParams();
const int kNumDevs = num_devs_per_task * num_tasks;
cp.group.group_key = 1;
cp.group.group_size = kNumDevs;
cp.group.device_type = DeviceType("GPU");
cp.group.num_tasks = num_tasks;
cp.instance.instance_key = 3;
cp.instance.type = GATHER_COLLECTIVE;
cp.instance.data_type = DataType(DT_FLOAT);
cp.instance.shape = TensorShape({kNumDevs * kNumDevs});
cp.instance.impl_details.collective_name = "RingGather";
cp.instance.impl_details.subdiv_offsets.push_back(0);
cp.is_source = false;
cp->group.group_key = 1;
cp->group.group_size = kNumDevs;
cp->group.device_type = DeviceType("GPU");
cp->group.num_tasks = num_tasks;
cp->instance.instance_key = 3;
cp->instance.type = GATHER_COLLECTIVE;
cp->instance.data_type = DataType(DT_FLOAT);
cp->instance.shape = TensorShape({kNumDevs * kNumDevs});
cp->instance.impl_details.collective_name = "RingGather";
cp->instance.impl_details.subdiv_offsets.push_back(0);
cp->is_source = false;
for (int i = 0; i < kNumDevs; ++i) {
int task_id = i / num_devs_per_task;
int dev_id = i % num_devs_per_task;
string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
cp.group.task_names.push_back(task_name);
cp.group.device_names.push_back(device_name);
cp->group.task_names.push_back(task_name);
cp->group.device_names.push_back(device_name);
}
return cp;
}
@@ -559,22 +565,23 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
TEST_F(RingGathererTest, InitializeParams) {
const int kNumDevsPerTask = 8;
const int kNumTasks = 3;
CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
core::ScopedUnref unref(cp);
cp.default_rank = 0;
cp.instance.impl_details.subdiv_offsets = {};
RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
cp->default_rank = 0;
cp->instance.impl_details.subdiv_offsets = {};
RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
{0});
cp.instance.impl_details.subdiv_offsets = {0};
RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
cp->instance.impl_details.subdiv_offsets = {0};
RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
{0});
cp.default_rank = 3;
cp.instance.impl_details.subdiv_offsets = {};
RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
cp->default_rank = 3;
cp->instance.impl_details.subdiv_offsets = {};
RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
{3});
}

View File

@@ -138,7 +138,8 @@ static int64 kStepId = 123;
class RingReducerTest : public ::testing::Test {
protected:
RingReducerTest() : device_type_(DEVICE_CPU) {}
RingReducerTest()
: device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
void InitGPUDevices() {
@@ -155,6 +156,7 @@ class RingReducerTest : public ::testing::Test {
stop_ = true;
for (auto i : instances_) delete i;
if (col_exec_) col_exec_->Unref();
if (col_params_) col_params_->Unref();
}
void Init(int num_workers, int num_devices, DataType dtype,
@@ -203,24 +205,25 @@ class RingReducerTest : public ::testing::Test {
col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
dev_mgr_.get(),
gpu_ring_order_.get(), work_queue_);
col_params_.name = "test_collective";
col_params_ = new CollectiveParams();
col_params_->name = "test_collective";
static const int kGroupKey = 5;
col_params_.group.group_key = kGroupKey;
col_params_.group.device_type = device_type;
col_params_.group.group_size = num_workers * num_devices;
col_params_->group.group_key = kGroupKey;
col_params_->group.device_type = device_type;
col_params_->group.group_size = num_workers * num_devices;
static const int kInstanceKey = 17;
col_params_.instance.instance_key = kInstanceKey;
col_params_.instance.impl_details.subdiv_offsets.clear();
col_params_.instance.type = REDUCTION_COLLECTIVE;
col_params_.instance.impl_details.collective_name = "RingReduce";
col_params_.instance.data_type = dtype;
col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs);
col_params_.subdiv_rank.resize(num_subdivs);
col_params_->instance.instance_key = kInstanceKey;
col_params_->instance.impl_details.subdiv_offsets.clear();
col_params_->instance.type = REDUCTION_COLLECTIVE;
col_params_->instance.impl_details.collective_name = "RingReduce";
col_params_->instance.data_type = dtype;
col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs);
col_params_->subdiv_rank.resize(num_subdivs);
int subdiv_stride = num_devices / num_subdivs;
for (int sdi = 0; sdi < num_subdivs; ++sdi) {
col_params_.instance.impl_details.subdiv_offsets.push_back(sdi *
subdiv_stride);
col_params_.subdiv_rank[sdi] = sdi * subdiv_stride;
col_params_->instance.impl_details.subdiv_offsets.push_back(
sdi * subdiv_stride);
col_params_->subdiv_rank[sdi] = sdi * subdiv_stride;
}
// Set up a local device ring order that's not just 0,1,2...
@@ -242,23 +245,23 @@ class RingReducerTest : public ::testing::Test {
// Set up all of the fake device contexts.
for (int wi = 0; wi < num_workers; ++wi) {
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
col_params_.group.num_devices_per_task[task_name] = num_devices;
col_params_->group.num_devices_per_task[task_name] = num_devices;
for (int di = 0; di < num_devices; ++di) {
string dev_name = strings::StrCat(task_name, "/cpu:", di);
if (device_type == DEVICE_GPU) {
dev_name =
strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size());
}
col_params_.group.device_names.push_back(dev_name);
col_params_.group.task_names.push_back(task_name);
col_params_->group.device_names.push_back(dev_name);
col_params_->group.task_names.push_back(task_name);
// Normally each device would set is_local to its own perspective but
// this test runs in a single process so is_local is always true.
col_params_.task.is_local.push_back(true);
col_params_->task.is_local.push_back(true);
for (int sdi = 0; sdi < num_subdivs; ++sdi) {
int rotated_di =
(di + col_params_.instance.impl_details.subdiv_offsets[sdi]) %
(di + col_params_->instance.impl_details.subdiv_offsets[sdi]) %
num_devices;
col_params_.instance.impl_details.subdiv_permutations[sdi].push_back(
col_params_->instance.impl_details.subdiv_permutations[sdi].push_back(
wi * num_devices + local_ring_order[rotated_di]);
}
}
@@ -267,7 +270,7 @@ class RingReducerTest : public ::testing::Test {
for (int di = 0; di < num_devices; ++di) {
int rank = wi * num_devices + di;
instances_.push_back(new DeviceInstance(
rank, col_params_.group.device_names[rank], device_type_, this));
rank, col_params_->group.device_names[rank], device_type_, this));
}
}
}
@@ -413,39 +416,42 @@ class RingReducerTest : public ::testing::Test {
: parent_(parent),
dev_name_(dev_name),
device_type_(device_type),
rank_(rank) {
rank_(rank),
col_params_(new CollectiveParams()) {
TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_))
<< "Couldn't find device " << dev_name
<< " existing devices: " << parent_->dev_mgr_->DebugString();
col_params_.name = parent_->col_params_.name;
col_params_.group = parent_->col_params_.group;
col_params_.instance = parent->col_params_.instance;
col_params_.task.is_local = parent_->col_params_.task.is_local;
col_params_.subdiv_rank = parent_->col_params_.subdiv_rank;
col_params_->name = parent_->col_params_->name;
col_params_->group = parent_->col_params_->group;
col_params_->instance = parent->col_params_->instance;
col_params_->task.is_local = parent_->col_params_->task.is_local;
col_params_->subdiv_rank = parent_->col_params_->subdiv_rank;
int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
int group_size = col_params_.group.group_size;
int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
int group_size = col_params_->group.group_size;
CHECK_EQ(group_size,
static_cast<int>(col_params_.group.device_names.size()));
static_cast<int>(col_params_->group.device_names.size()));
// Id of this device is at rank position in first subdiv perm.
int my_device_id =
col_params_.instance.impl_details.subdiv_permutations[0][rank];
col_params_.default_rank = my_device_id;
col_params_->instance.impl_details.subdiv_permutations[0][rank];
col_params_->default_rank = my_device_id;
// Set rank for all other subdivs by finding that device_id.
for (int sdi = 0; sdi < num_subdivs; ++sdi) {
for (int r = 0; r < static_cast<int>(col_params_.instance.impl_details
for (int r = 0; r < static_cast<int>(col_params_->instance.impl_details
.subdiv_permutations[sdi]
.size());
++r) {
if (my_device_id ==
col_params_.instance.impl_details.subdiv_permutations[sdi][r]) {
col_params_.subdiv_rank[sdi] = r;
col_params_->instance.impl_details.subdiv_permutations[sdi][r]) {
col_params_->subdiv_rank[sdi] = r;
break;
}
}
}
}
~DeviceInstance() { col_params_->Unref(); }
void InitTensor(DataType dtype, const TensorShape& shape,
const std::function<void(Tensor*)>& init_f) {
tensor_ =
@@ -466,10 +472,12 @@ class RingReducerTest : public ::testing::Test {
}
void DoReduce() {
merge_op_ = GetAdd(col_params_.instance.data_type, device_type_, device_);
final_op_ = GetDiv(col_params_.instance.data_type, device_type_, device_);
col_params_.merge_op = merge_op_.get();
col_params_.final_op = final_op_.get();
merge_op_ =
GetAdd(col_params_->instance.data_type, device_type_, device_);
final_op_ =
GetDiv(col_params_->instance.data_type, device_type_, device_);
col_params_->merge_op = merge_op_.get();
col_params_->final_op = final_op_.get();
// Prepare an OpKernelContext.
OpKernelContext::Params op_params;
@@ -496,7 +504,7 @@ class RingReducerTest : public ::testing::Test {
AllocatorAttributes generic_alloc_attr;
op_params.output_attr_array = &generic_alloc_attr;
std::unique_ptr<OpKernel> op = parent_->GetCollectiveReduce(
col_params_, &tensor_, DEVICE_CPU, device_);
*col_params_, &tensor_, DEVICE_CPU, device_);
op_params.op_kernel = op.get();
OpKernelContext ctx(&op_params, 1);
@@ -509,7 +517,7 @@ class RingReducerTest : public ::testing::Test {
// Prepare a RingReducer instance.
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
strings::StrCat(col_params_->instance.instance_key, ":0:0");
RingReducer* reducer = new RingReducer;
core::ScopedUnref unref(reducer);
auto col_ctx = std::make_shared<CollectiveContext>(
@@ -535,7 +543,7 @@ class RingReducerTest : public ::testing::Test {
int rank_;
Tensor tensor_;
Device* device_;
CollectiveParams col_params_;
CollectiveParams* col_params_;
std::unique_ptr<OpKernel> merge_op_;
std::unique_ptr<OpKernel> final_op_;
std::unique_ptr<CollectiveAdapter> ca_;
@@ -551,7 +559,7 @@ class RingReducerTest : public ::testing::Test {
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
CollectiveParams* col_params_;
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
std::unique_ptr<string> gpu_ring_order_;
@@ -560,28 +568,28 @@ class RingReducerTest : public ::testing::Test {
CancellationManager cancellation_manager_;
};
CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
CollectiveParams* SetUpCollectiveParams(const int num_devs_per_task,
const int num_tasks) {
CollectiveParams cp;
auto cp = new CollectiveParams();
const int kNumDevs = num_devs_per_task * num_tasks;
cp.group.group_key = 1;
cp.group.group_size = kNumDevs;
cp.group.device_type = DeviceType("GPU");
cp.group.num_tasks = num_tasks;
cp.instance.instance_key = 3;
cp.instance.type = REDUCTION_COLLECTIVE;
cp.instance.data_type = DataType(DT_FLOAT);
cp.instance.shape = TensorShape({kNumDevs});
cp.instance.impl_details.collective_name = "RingReduce";
cp.instance.impl_details.subdiv_offsets.push_back(0);
cp.is_source = false;
cp->group.group_key = 1;
cp->group.group_size = kNumDevs;
cp->group.device_type = DeviceType("GPU");
cp->group.num_tasks = num_tasks;
cp->instance.instance_key = 3;
cp->instance.type = REDUCTION_COLLECTIVE;
cp->instance.data_type = DataType(DT_FLOAT);
cp->instance.shape = TensorShape({kNumDevs});
cp->instance.impl_details.collective_name = "RingReduce";
cp->instance.impl_details.subdiv_offsets.push_back(0);
cp->is_source = false;
for (int i = 0; i < kNumDevs; ++i) {
int task_id = i / num_devs_per_task;
int dev_id = i % num_devs_per_task;
string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
cp.group.task_names.push_back(task_name);
cp.group.device_names.push_back(device_name);
cp->group.task_names.push_back(task_name);
cp->group.device_names.push_back(device_name);
}
return cp;
}
@ -589,28 +597,29 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
TEST_F(RingReducerTest, InitializeParams) {
const int kNumDevsPerTask = 8;
const int kNumTasks = 3;
CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
core::ScopedUnref unref(cp);
cp.default_rank = 0;
cp.instance.impl_details.subdiv_offsets = {0, 4};
RunSubdivPermsTest(&cp,
cp->default_rank = 0;
cp->instance.impl_details.subdiv_offsets = {0, 4};
RunSubdivPermsTest(cp,
{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15,
8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
{0, 4});
cp.instance.impl_details.subdiv_offsets = {0, -4};
RunSubdivPermsTest(&cp,
cp->instance.impl_details.subdiv_offsets = {0, -4};
RunSubdivPermsTest(cp,
{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
{0, 3});
cp.default_rank = 3;
cp.instance.impl_details.subdiv_offsets = {3, -3};
RunSubdivPermsTest(&cp,
cp->default_rank = 3;
cp->instance.impl_details.subdiv_offsets = {3, -3};
RunSubdivPermsTest(cp,
{{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18},
{4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9,
@ -622,12 +631,13 @@ TEST_F(RingReducerTest, AutomaticSubdivs) {
const int kNumDevsPerTask = 8;
const int kNumTasks = 3;
const int kNumDevs = kNumDevsPerTask * kNumTasks;
CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
core::ScopedUnref unref(cp);
// Test automatic generation of subdiv offsets.
cp.default_rank = 0;
cp.instance.impl_details.subdiv_offsets.clear();
RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
cp->default_rank = 0;
cp->instance.impl_details.subdiv_offsets.clear();
RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
{0});
@ -638,11 +648,11 @@ TEST_F(RingReducerTest, AutomaticSubdivs) {
int num_chunks = kNumDevs * num_subdivs;
size_t chunk_size = 3 * 1048576; // 3 MB
size_t tensor_size = chunk_size * num_chunks;
cp.instance.shape =
cp->instance.shape =
TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))});
}
cp.instance.impl_details.subdiv_offsets.clear();
RunSubdivPermsTest(&cp,
cp->instance.impl_details.subdiv_offsets.clear();
RunSubdivPermsTest(cp,
{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
@ -653,12 +663,13 @@ TEST_F(RingReducerTest, AutomaticSubdivs) {
TEST_F(RingReducerTest, AutomaticSubdivUpperBound) {
const int kNumDevsPerTask = 1;
const int kNumTasks = 4;
CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
core::ScopedUnref unref(cp);
cp.default_rank = 0;
cp.instance.impl_details.subdiv_offsets.clear();
cp.instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)});
RunSubdivPermsTest(&cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0});
cp->default_rank = 0;
cp->instance.impl_details.subdiv_offsets.clear();
cp->instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)});
RunSubdivPermsTest(cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0});
}
// TODO(b/113171733): change to use TEST_P.


@ -137,13 +137,14 @@ void CollectiveParamResolverDistributed::CompleteGroupAsync(
"running the same version of Tensorflow on all workers."));
return;
}
CollectiveParams cp;
cp.group.group_key = request->group_key();
cp.group.group_size = request->group_size();
cp.group.device_type = DeviceType(request->device_type());
cp.instance.type = CollectiveType(request->collective_type());
auto* cp = new CollectiveParams();
core::ScopedUnref unref(cp);
cp->group.group_key = request->group_key();
cp->group.group_size = request->group_size();
cp->group.device_type = DeviceType(request->device_type());
cp->instance.type = CollectiveType(request->collective_type());
CompleteGroupDistributed(
request->device_attributes(), &cp, cancel_mgr,
request->device_attributes(), cp, cancel_mgr,
[response, done](const Status& s, const GroupRec* gr) {
if (s.ok()) {
mutex_lock l(gr->mu);
@ -196,7 +197,7 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync(
}
StatusCallback done_and_cleanup = [cp, done](const Status& s) {
done(s);
delete cp;
cp->Unref();
};
CompleteInstanceDistributed(
request->device(), gr, cp, cancel_mgr,


@ -127,6 +127,13 @@ class FakeCache : public TestWorkerCache {
};
class DeviceResDistTest : public ::testing::Test {
public:
~DeviceResDistTest() override {
for (auto& name_param : cp_) {
name_param.second->Unref();
}
}
protected:
void DefineWorkers(int num_workers, int num_devices,
const string& device_type, bool nccl) {
@ -181,20 +188,20 @@ class DeviceResDistTest : public ::testing::Test {
}
}
CollectiveParams CreateCollectiveParams(int num_workers, int num_devices,
CollectiveParams* CreateCollectiveParams(int num_workers, int num_devices,
const string& device_type) {
const int kGroupKey = 5;
const int kInstanceKey = 3;
CollectiveParams cp;
cp.group.group_key = kGroupKey;
cp.group.group_size = num_workers * num_devices;
cp.group.device_type = DeviceType(device_type);
cp.group.num_tasks = num_workers;
cp.instance.instance_key = kInstanceKey;
cp.instance.type = REDUCTION_COLLECTIVE;
cp.instance.data_type = DT_FLOAT;
cp.instance.shape = TensorShape({64});
cp.instance.impl_details.subdiv_offsets.push_back(0);
auto* cp = new CollectiveParams();
cp->group.group_key = kGroupKey;
cp->group.group_size = num_workers * num_devices;
cp->group.device_type = DeviceType(device_type);
cp->group.num_tasks = num_workers;
cp->instance.instance_key = kInstanceKey;
cp->instance.type = REDUCTION_COLLECTIVE;
cp->instance.data_type = DT_FLOAT;
cp->instance.shape = TensorShape({64});
cp->instance.impl_details.subdiv_offsets.push_back(0);
return cp;
}
@ -217,7 +224,7 @@ class DeviceResDistTest : public ::testing::Test {
int group_size) {
Device* device = nullptr;
TF_CHECK_OK(device_mgrs_[task_name]->LookupDevice(device_name, &device));
CollectiveParams* cp = &cp_[device_name];
CollectiveParams* cp = cp_[device_name];
CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name].get();
CHECK(cp_res);
cp_res->CompleteParamsAsync(
@ -252,19 +259,19 @@ class DeviceResDistTest : public ::testing::Test {
string device_name = strings::StrCat(task_name, "/device:CPU:", di);
int idx = wi * num_devices + di;
TF_ASSERT_OK(status_[device_name]);
EXPECT_EQ(cp_[device_name].default_rank, idx);
EXPECT_EQ(cp_[device_name].group.device_names.size(), dev_count);
EXPECT_EQ(cp_[device_name].group.device_names[idx], device_name);
EXPECT_EQ(cp_[device_name].group.task_names[idx], task_name);
ValidateDeviceResolver(cp_[device_name], task_name);
EXPECT_EQ(cp_[device_name]->default_rank, idx);
EXPECT_EQ(cp_[device_name]->group.device_names.size(), dev_count);
EXPECT_EQ(cp_[device_name]->group.device_names[idx], device_name);
EXPECT_EQ(cp_[device_name]->group.task_names[idx], task_name);
ValidateDeviceResolver(*cp_[device_name], task_name);
if (idx > 0) {
EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key,
cp_[device_name].group.runtime_details.communicator_key);
EXPECT_EQ(cp_[dev0]->group.runtime_details.communicator_key,
cp_[device_name]->group.runtime_details.communicator_key);
for (int i = 0; i < dev_count; ++i) {
EXPECT_EQ(cp_[dev0].group.device_names[i],
cp_[device_name].group.device_names[i]);
EXPECT_EQ(cp_[dev0].group.task_names[i],
cp_[device_name].group.task_names[i]);
EXPECT_EQ(cp_[dev0]->group.device_names[i],
cp_[device_name]->group.device_names[i]);
EXPECT_EQ(cp_[dev0]->group.task_names[i],
cp_[device_name]->group.task_names[i]);
}
}
}
@ -287,6 +294,9 @@ class DeviceResDistTest : public ::testing::Test {
for (int i = 0; i < num_devices; ++i) {
string device_name =
strings::StrCat(worker_name, "/device:", device_type, ":", i);
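// A previous call may have populated this slot; release the old ref before
// overwriting it so the map does not leak params.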
if (cp_.find(device_name) != cp_.end()) {
cp_[device_name]->Unref();
}
cp_[device_name] =
CreateCollectiveParams(num_workers, num_devices, device_type);
status_.erase(device_name);
@ -305,7 +315,7 @@ class DeviceResDistTest : public ::testing::Test {
absl::flat_hash_map<string, std::vector<string>> dev_by_task_;
absl::flat_hash_map<string, std::unique_ptr<FakeWorker>> workers_;
// Below are keyed by device names.
absl::flat_hash_map<string, CollectiveParams> cp_;
absl::flat_hash_map<string, CollectiveParams*> cp_;
absl::flat_hash_map<string, Status> status_;
mutex mu_;
int num_done_ TF_GUARDED_BY(mu_);


@ -169,7 +169,7 @@ string CollectiveParams::ToString() const {
CollectiveContext::CollectiveContext(
CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
const DeviceMgr* dev_mgr, OpKernelContext* ctx,
OpKernelContext::Params* op_params, const CollectiveParams& col_params,
OpKernelContext::Params* op_params, const CollectiveParams* col_params,
const string& exec_key, int64 step_id, const Tensor* input, Tensor* output)
: col_exec(col_exec),
nccl_communicator(nccl_communicator),
@ -182,7 +182,7 @@ CollectiveContext::CollectiveContext(
input(input),
output(output),
device(nullptr),
device_name(col_params.group.device_names[col_params.default_rank]) {}
device_name(col_params->group.device_names[col_params->default_rank]) {}
/*static*/
int64 CollectiveExecutor::kInvalidId = -1;


@ -132,7 +132,7 @@ struct CollTaskParams {
};
// Unique to a single CollectiveOp node.
struct CollectiveParams {
struct CollectiveParams : public core::RefCounted {
CollGroupParams group;
CollInstanceParams instance;
CollTaskParams task;
@ -298,7 +298,7 @@ class CollectiveExecutor : public core::RefCounted {
virtual void StartAbort(const Status& s) {}
virtual void ExecuteAsync(OpKernelContext* ctx,
const CollectiveParams& col_params,
const CollectiveParams* col_params,
const string& exec_key, StatusCallback done) {
done(errors::Internal(
"A collective Op has been called in a context in which "
@ -367,7 +367,7 @@ struct CollectiveContext {
const DeviceMgr* dev_mgr; // Not owned
OpKernelContext* op_ctx; // Not owned
OpKernelContext::Params* op_params; // Not owned
const CollectiveParams& col_params;
const CollectiveParams* col_params; // Not owned
const string exec_key;
const int64 step_id;
const Tensor* input; // Not owned
@ -380,7 +380,7 @@ struct CollectiveContext {
NcclCommunicatorInterface* nccl_communicator,
const DeviceMgr* dev_mgr, OpKernelContext* ctx,
OpKernelContext::Params* op_params,
const CollectiveParams& col_params, const string& exec_key,
const CollectiveParams* col_params, const string& exec_key,
int64 step_id, const Tensor* input, Tensor* output);
};
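With `struct CollectiveParams : public core::RefCounted` above, every holder of the params takes and releases its own reference instead of depending on a single owner's scope. The following is a self-contained sketch of that contract; the toy `RefCounted`, `Params`, and `main` are illustrative stand-ins modeled on `tensorflow::core::RefCounted` semantics (a new object starts with one ref), not TensorFlow code.

#include <atomic>
#include <cassert>
#include <thread>

class RefCounted {
 public:
  void Ref() { count_.fetch_add(1); }
  void Unref() {
    // Whichever holder releases last deletes the object.
    if (count_.fetch_sub(1) == 1) delete this;
  }

 protected:
  virtual ~RefCounted() = default;

 private:
  std::atomic<int> count_{1};  // A new object is born with one ref.
};

struct Params : public RefCounted {
  int group_key = 1;
};

int main() {
  auto* p = new Params();  // refcount == 1, owned by this scope.
  p->Ref();                // +1, transferred to the worker thread below.
  std::thread worker([p] {
    assert(p->group_key == 1);  // Safe: the worker holds its own ref.
    p->Unref();                 // The worker releases its ref when done.
  });
  p->Unref();  // This scope releases its ref; the last Unref deletes p.
  worker.join();
  return 0;
}

The kernels in this diff follow the same shape: a long-lived ref held from constructor to destructor, plus a temporary Ref/Unref pair around each async hand-off.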


@ -61,7 +61,7 @@ Status NcclBase::InitializeCollectiveParams(CollectiveParams* col_params) {
Status NcclBase::InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) {
col_ctx_ = col_ctx;
col_params_ = &col_ctx->col_params;
col_params_ = col_ctx->col_params;
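// Borrowed, not owned: the launching kernel takes a ref before ExecuteAsync
// and holds it until the done callback runs, which keeps *col_params_ alive
// for the duration of this collective.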
return collective_util::InitializeDeviceAndLocality(
col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
&col_ctx->device_locality);


@ -88,6 +88,9 @@ void NcclReducer::Run(StatusCallback done) {
} else {
done_callback = std::move(done);
}
// Hold a ref to col_params for the rest of this function.
col_params_->Ref();
core::ScopedUnref unref(col_params_);
col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done_callback));
// If no final_op, then this OpKernel is non-blocking.


@ -53,7 +53,9 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
class CollectiveOpV1Kernel : public AsyncOpKernel {
public:
explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
: AsyncOpKernel(c), name_(name()) {}
: AsyncOpKernel(c), name_(name()), col_params_(new CollectiveParams()) {}
~CollectiveOpV1Kernel() override { col_params_->Unref(); }
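// The kernel holds one ref on col_params_ for its entire lifetime; the async
// paths below take additional, temporary refs of their own.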
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
@ -88,28 +90,31 @@ class CollectiveOpV1Kernel : public AsyncOpKernel {
// A string encoding instance, frame and iter to be handed off to
// the implementation for use in generating RecvBuf keys.
string GetCollectiveKey(OpKernelContext* c) {
return CollectiveKey(c, col_params_.group.group_key,
col_params_.instance.instance_key);
return CollectiveKey(c, col_params_->group.group_key,
col_params_->instance.instance_key);
}
// Returns false if the calling invocation of ComputeAsync should return
// immediately.
bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
const DoneCallback& done) {
if (col_params_.group.group_size > col_params_.group.device_names.size()) {
if (col_params_->group.group_size >
col_params_->group.device_names.size()) {
// This is the first invocation: Finish initializing col_params_.
// Schedule the `CompleteParamsAsync` call on a work queue that can handle
// blocking work, since this call may block.
c->collective_executor()->RunClosure([this, c, done, col_exec]() {
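// Take a ref so col_params_ stays valid inside the closure; the ScopedUnref
// in the CompleteParamsAsync callback below releases it.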
col_params_->Ref();
c->collective_executor()->RunClosure([this, c, col_exec, done]() {
VLOG(1) << "CollectiveOpKernel CompleteParams for collective "
<< col_params_.name << " device " << c->device()->name()
<< " group " << col_params_.group.group_key << " instance "
<< col_params_.instance.instance_key;
<< col_params_->name << " device " << c->device()->name()
<< " group " << col_params_->group.group_key << " instance "
<< col_params_->instance.instance_key;
col_exec->CompleteParamsAsync(
c->device()->attributes(), &col_params_, c->cancellation_manager(),
c->device()->attributes(), col_params_, c->cancellation_manager(),
[this, c, done](const Status& s) {
core::ScopedUnref unref(col_params_);
if (s.ok()) {
col_params_.instance.impl_details.dependencies = dependencies_;
col_params_->instance.impl_details.dependencies = dependencies_;
ComputeAsync(c, done);
} else {
c->SetStatus(s);
@ -128,7 +133,7 @@ class CollectiveOpV1Kernel : public AsyncOpKernel {
DoneCallback done) = 0;
string name_;
CollectiveParams col_params_;
CollectiveParams* col_params_;
std::vector<int32> dependencies_;
};
@ -136,25 +141,25 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = GATHER_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
col_params_->instance.type = GATHER_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
c, col_params_->group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
col_params_->group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
&col_params_->instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
&col_params_->instance.impl_details.timeout_seconds));
const NodeDef& real_node = c->def();
col_params_.name = strings::StrCat(real_node.name(), ": Gather");
col_params_.group.device_type = c->device_type();
col_params_->name = strings::StrCat(real_node.name(), ": Gather");
col_params_->group.device_type = c->device_type();
}
protected:
@ -162,8 +167,8 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
DoneCallback done) override {
auto output_shape = c->input(0).shape();
output_shape.set_dim(
0, output_shape.dim_size(0) * col_params_.group.group_size);
col_params_.instance.shape = output_shape;
0, output_shape.dim_size(0) * col_params_->group.group_size);
col_params_->instance.shape = output_shape;
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
@ -173,24 +178,24 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
// Allocate the output tensor.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->allocate_output(0, col_params_.instance.shape, &output), done);
c, c->allocate_output(0, col_params_->instance.shape, &output), done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
auto actual_done = [c, group_key = col_params_.group.group_key,
instance_key = col_params_.instance.instance_key,
done](const Status& s) {
auto actual_done = [c, col_params = col_params_, done](const Status& s) {
VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync done for collective "
<< c->op_kernel().name() << " device " << c->device()->name()
<< " group " << group_key << " instance " << instance_key
<< " status " << s;
<< " group " << col_params->group.group_key << " instance "
<< col_params->instance.instance_key << " status " << s;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
col_params->Unref();
};
VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync start for collective "
<< col_params_.name << " device " << c->device()->name()
<< " group " << col_params_.group.group_key << " instance "
<< col_params_.instance.instance_key;
<< col_params_->name << " device " << c->device()->name()
<< " group " << col_params_->group.group_key << " instance "
<< col_params_->instance.instance_key;
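// Pass a ref along with the ExecuteAsync call; actual_done releases it once
// the collective completes. The other V1 kernels below follow the same
// ref-transfer pattern.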
col_params_->Ref();
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
@ -207,18 +212,18 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = REDUCTION_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
col_params_->instance.type = REDUCTION_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
c, col_params_->group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
col_params_->group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
OP_REQUIRES_OK(
c, c->GetAttr("subdiv_offsets",
&col_params_.instance.impl_details.subdiv_offsets));
&col_params_->instance.impl_details.subdiv_offsets));
string merge_op_name;
OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
if (merge_op_name == "Max") {
@ -232,24 +237,26 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
errors::InvalidArgument(
"final_op must be one of {\"Id\", \"Div\"} but got ",
final_op_name));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
OP_REQUIRES_OK(c, c->GetAttr("wait_for", &dependencies_));
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
&col_params_->instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
VLOG(2) << "CollectiveReduce instance " << col_params_.instance.instance_key
<< " merge_op " << merge_op_name << " final_op " << final_op_name
&col_params_->instance.impl_details.timeout_seconds));
VLOG(2) << "CollectiveReduce instance "
<< col_params_->instance.instance_key << " merge_op "
<< merge_op_name << " final_op " << final_op_name
<< " communication_hint "
<< col_params_.instance.impl_details.communication_hint
<< " timeout " << col_params_.instance.impl_details.timeout_seconds;
<< col_params_->instance.impl_details.communication_hint
<< " timeout "
<< col_params_->instance.impl_details.timeout_seconds;
const NodeDef& real_node = c->def();
col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
col_params_->name = strings::StrCat(real_node.name(), ": Reduce(",
merge_op_name, ",", final_op_name, ")");
col_params_.group.device_type = c->device_type();
col_params_->group.device_type = c->device_type();
// Find the OpKernels by name, type and device type.
NodeDef sub_node;
@ -257,12 +264,12 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
sub_node.add_input(real_node.input(0));
sub_node.add_input(real_node.input(0));
sub_node.set_device(real_node.device());
SetAttrValue(col_params_.instance.data_type,
SetAttrValue(col_params_->instance.data_type,
&(*sub_node.mutable_attr())["T"]);
merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
col_params_.merge_op = merge_op_.get();
col_params_.final_op = final_op_.get();
col_params_->merge_op = merge_op_.get();
col_params_->final_op = final_op_.get();
}
protected:
@ -279,24 +286,24 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
c->forward_input_or_allocate_output(
{0}, 0, c->input(0).shape(), &output),
done);
col_params_.instance.shape = c->input(0).shape();
col_params_->instance.shape = c->input(0).shape();
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
auto actual_done = [c, group_key = col_params_.group.group_key,
instance_key = col_params_.instance.instance_key,
done](const Status& s) {
auto actual_done = [c, col_params = col_params_, done](const Status& s) {
VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync done for collective "
<< c->op_kernel().name() << " device " << c->device()->name()
<< " group " << group_key << " instance " << instance_key
<< " status " << s;
<< " group " << col_params->group.group_key << " instance "
<< col_params->instance.instance_key << " status " << s;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
col_params->Unref();
};
VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync start for collective "
<< col_params_.name << " device " << c->device()->name()
<< " group " << col_params_.group.group_key << " instance "
<< col_params_.instance.instance_key;
<< col_params_->name << " device " << c->device()->name()
<< " group " << col_params_->group.group_key << " instance "
<< col_params_->instance.instance_key;
col_params_->Ref();
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
@ -315,29 +322,29 @@ class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
col_params_->instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
c, col_params_->group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
col_params_->group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape));
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
&col_params_->instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
col_params_.is_source = true;
col_params_.instance.impl_details.subdiv_offsets = {0};
&col_params_->instance.impl_details.timeout_seconds));
col_params_->is_source = true;
col_params_->instance.impl_details.subdiv_offsets = {0};
col_params_.name =
strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
col_params_.group.device_type = c->device_type();
col_params_->name =
strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
col_params_->group.device_type = c->device_type();
}
protected:
@ -352,30 +359,30 @@ class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(c,
c->forward_input_or_allocate_output(
{0}, 0, col_params_.instance.shape, &output),
{0}, 0, col_params_->instance.shape, &output),
done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
OP_REQUIRES_ASYNC(
c, col_params_.instance.shape.IsSameSize(c->input(0).shape()),
errors::Internal("Declared shape of op ", col_params_.name,
c, col_params_->instance.shape.IsSameSize(c->input(0).shape()),
errors::Internal("Declared shape of op ", col_params_->name,
" does not match shape of input"),
done);
auto actual_done = [c, group_key = col_params_.group.group_key,
instance_key = col_params_.instance.instance_key,
done](const Status& s) {
auto actual_done = [c, col_params = col_params_, done](const Status& s) {
VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for collective "
<< c->op_kernel().name() << " device " << c->device()->name()
<< " group " << group_key << " instance " << instance_key
<< " status " << s;
<< " group " << col_params->group.group_key << " instance "
<< col_params->instance.instance_key << " status " << s;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
col_params->Unref();
};
VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync start for collective "
<< col_params_.name << " device " << c->device()->name()
<< " group " << col_params_.group.group_key << " instance "
<< col_params_.instance.instance_key;
<< col_params_->name << " device " << c->device()->name()
<< " group " << col_params_->group.group_key << " instance "
<< col_params_->instance.instance_key;
col_params_->Ref();
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
@ -392,29 +399,29 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
col_params_->instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
c, col_params_->group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
col_params_->group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape));
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
&col_params_->instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
col_params_.is_source = false;
col_params_.instance.impl_details.subdiv_offsets = {0};
&col_params_->instance.impl_details.timeout_seconds));
col_params_->is_source = false;
col_params_->instance.impl_details.subdiv_offsets = {0};
col_params_.name =
strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
col_params_.group.device_type = c->device_type();
col_params_->name =
strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
col_params_->group.device_type = c->device_type();
}
protected:
@ -428,24 +435,24 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
// No input, so must allocate output.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->allocate_output(0, col_params_.instance.shape, &output), done);
c, c->allocate_output(0, col_params_->instance.shape, &output), done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
auto actual_done = [c, group_key = col_params_.group.group_key,
instance_key = col_params_.instance.instance_key,
done](const Status& s) {
auto actual_done = [c, col_params = col_params_, done](const Status& s) {
VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for collective "
<< c->op_kernel().name() << " device " << c->device()->name()
<< " group " << group_key << " instance_key " << instance_key
<< " status " << s;
<< " group " << col_params->group.group_key << " instance_key "
<< col_params->instance.instance_key << " status " << s;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
col_params->Unref();
};
VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync start for collective "
<< col_params_.name << " device " << c->device()->name()
<< " group " << col_params_.group.group_key << " instance "
<< col_params_.instance.instance_key;
<< col_params_->name << " device " << c->device()->name()
<< " group " << col_params_->group.group_key << " instance "
<< col_params_->instance.instance_key;
col_params_->Ref();
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
@ -534,8 +541,8 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
<< col_params->instance.instance_key;
auto done_with_cleanup = [col_params, done = std::move(done)]() {
delete col_params;
done();
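// Unref rather than delete: other holders may still have refs in flight, so
// only the final release destroys col_params.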
col_params->Unref();
};
// Allocate the output tensor, trying to reuse the input.
@ -577,7 +584,7 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
<< " group " << col_params->group.group_key
<< " instance " << col_params->instance.instance_key;
col_exec->ExecuteAsync(
c, *col_params,
c, col_params,
CollectiveKey(c, col_params->group.group_key,
col_params->instance.instance_key),
actual_done);
@ -673,8 +680,8 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
col_params->instance.shape = output_shape;
auto done_with_cleanup = [col_params, done = std::move(done)]() {
delete col_params;
done();
col_params->Unref();
};
Tensor* output = nullptr;
@ -714,7 +721,7 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
<< " group " << col_params->group.group_key
<< " instance " << col_params->instance.instance_key;
col_exec->ExecuteAsync(
c, *col_params,
c, col_params,
CollectiveKey(c, col_params->group.group_key,
col_params->instance.instance_key),
actual_done);
@ -797,8 +804,8 @@ class CollectiveBcastSendV2OpKernel : public AsyncOpKernel {
<< col_params->instance.instance_key;
auto done_with_cleanup = [col_params, done = std::move(done)]() {
delete col_params;
done();
col_params->Unref();
};
// Allocate the output tensor, trying to reuse the input.
@ -840,7 +847,7 @@ class CollectiveBcastSendV2OpKernel : public AsyncOpKernel {
<< " group " << col_params->group.group_key
<< " instance " << col_params->instance.instance_key;
col_exec->ExecuteAsync(
c, *col_params,
c, col_params,
CollectiveKey(c, col_params->group.group_key,
col_params->instance.instance_key),
actual_done);
@ -905,8 +912,8 @@ class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel {
auto col_params = new CollectiveParams();
auto done_with_cleanup = [col_params, done = std::move(done)]() {
delete col_params;
done();
col_params->Unref();
};
OP_REQUIRES_OK_ASYNC(
@ -969,7 +976,7 @@ class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel {
<< " group " << col_params->group.group_key
<< " instance " << col_params->instance.instance_key;
col_exec->ExecuteAsync(
c, *col_params,
c, col_params,
CollectiveKey(c, col_params->group.group_key,
col_params->instance.instance_key),
actual_done);


@ -69,17 +69,17 @@ std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator() {
void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
StatusCallback done) {
const CollectiveParams& col_params = col_ctx->col_params;
const int num_global_devices = col_params.group.group_size;
const int num_local_devices = col_params.group.num_devices_per_task.at(
col_params.group.task_names[col_params.default_rank]);
const CollectiveParams* col_params = col_ctx->col_params;
const int num_global_devices = col_params->group.group_size;
const int num_local_devices = col_params->group.num_devices_per_task.at(
col_params->group.task_names[col_params->default_rank]);
const string nccl_collective_key =
NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id);
auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream();
auto* gpu_info = col_ctx->op_ctx->device()->tensorflow_gpu_device_info();
auto participant = absl::make_unique<NcclManager::Participant>(
compute_stream->parent(), compute_stream, gpu_info, col_ctx->input,
col_ctx->output, col_ctx->col_params.default_rank,
col_ctx->output, col_ctx->col_params->default_rank,
/*done_callback=*/nullptr);
CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager();
if (cancel_mgr == nullptr) {
@ -105,15 +105,24 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
}
NcclManager::Context context(
nccl_collective_key, num_local_devices, num_global_devices,
col_params.group.runtime_details.communicator_key,
col_params.source_rank);
VLOG(1) << "NcclCommunicator::Enqueue type " << col_params.instance.type
<< " num_tasks " << col_params.group.num_tasks << " current task "
<< col_params.group.task_names[col_params.default_rank]
col_params->group.runtime_details.communicator_key,
col_params->source_rank);
VLOG(1) << "NcclCommunicator::Enqueue type " << col_params->instance.type
<< " num_tasks " << col_params->group.num_tasks << " current task "
<< col_params->group.task_names[col_params->default_rank]
<< " num local devices " << num_local_devices
<< " num global devices " << num_global_devices << " device "
<< col_ctx->device_name << " instance "
<< col_params.instance.instance_key;
<< col_params->instance.instance_key;
// Hold a ref to col_params for the rest of this function.
// NOTE: an alternate design would be one in which CollectiveParams is not
// refcounted. In such a design, we would need to ensure that the
// done_callback of each participant is called only after this function is
// done with accessing the params. This would likely require some
// coordination mechanism, and may even require the participant thread to
// block until after UnblockDependencies is called below.
col_params->Ref();
core::ScopedUnref unref(col_params);
// `AddTo*` performs consistency checks for the NCCL call and enqueues the
// `Participant` struct locally. When all local participants with this
// `nccl_collective_key` have called `AddToAllReduce` and
@ -123,10 +132,11 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
// The `NcclManager` uses a dedicated CUDA stream for NCCL kernels. At this
// point, it synchronizes the NCCL stream with the compute stream, and then
// enqueues the NCCL kernel on the NCCL stream.
switch (col_params.instance.type) {
switch (col_params->instance.type) {
case REDUCTION_COLLECTIVE: {
ncclRedOp_t reduction_op;
Status s = ReductionOp(col_params.merge_op->type_string(), &reduction_op);
Status s =
ReductionOp(col_params->merge_op->type_string(), &reduction_op);
if (!s.ok()) {
participant->done_callback(s);
return;
@ -140,7 +150,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
break;
}
case BROADCAST_COLLECTIVE: {
if (col_params.is_source) {
if (col_params->is_source) {
nccl_manager_.AddBroadcastSend(std::move(participant), context);
} else {
nccl_manager_.AddBroadcastRecv(std::move(participant), context);
@ -149,7 +159,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
}
default: {
participant->done_callback(errors::Internal("Unexpected CollectiveType ",
col_params.instance.type));
col_params->instance.type));
return;
}
}
@ -175,7 +185,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
// ready to go.
profiler::TraceMe activity("WaitForDependencies",
profiler::TraceMeLevel::kInfo);
col_ctx->col_exec->WaitForDependencies(col_params);
col_ctx->col_exec->WaitForDependencies(*col_params);
nccl_manager_.SignalMultiNodeReady(nccl_collective_key);
}
{
@ -184,7 +194,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
// implementation of `UnblockDependencies` keeps track of the number of
// devices that have launched.
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
col_ctx->col_exec->UnblockDependencies(col_params);
col_ctx->col_exec->UnblockDependencies(*col_params);
}
}