From 18eaf4e8f1769d2bb05ed354a8b6198e49aadfcc Mon Sep 17 00:00:00 2001 From: Ayush Dubey <ayushd@google.com> Date: Thu, 4 Feb 2021 09:40:21 -0800 Subject: [PATCH] Ensure that `CollectiveParams` outlives all references to it. Before this change, it was possible to access a `const CollectiveParams&` after it was destroyed. For example, the call to `UnblockDependencies` in `NcclCommunicator::Enqueue` raced with the done_callback of the collective participant. This change makes `CollectiveParams` a refcounted object, and holds references everywhere it may be accessed. PiperOrigin-RevId: 355646163 Change-Id: I7fd164afe8c1c9aa1c3b77a988930a0624977c7c --- .../base_collective_executor.cc | 16 +- .../common_runtime/base_collective_executor.h | 2 +- .../collective_param_resolver_local.cc | 16 +- .../collective_param_resolver_local.h | 6 +- .../collective_param_resolver_local_test.cc | 129 +++++---- .../hierarchical_tree_broadcaster.cc | 2 +- .../hierarchical_tree_broadcaster_test.cc | 229 ++++++++-------- tensorflow/core/common_runtime/permuter.cc | 2 +- .../core/common_runtime/permuter_test.cc | 68 ++--- tensorflow/core/common_runtime/ring_alg.cc | 2 +- .../core/common_runtime/ring_gatherer_test.cc | 141 +++++----- .../core/common_runtime/ring_reducer_test.cc | 175 +++++++------ .../collective_param_resolver_distributed.cc | 15 +- ...lective_param_resolver_distributed_test.cc | 60 +++-- tensorflow/core/framework/collective.cc | 4 +- tensorflow/core/framework/collective.h | 8 +- tensorflow/core/kernels/collective_nccl.cc | 2 +- .../core/kernels/collective_nccl_reducer.cc | 3 + tensorflow/core/kernels/collective_ops.cc | 247 +++++++++--------- .../core/nccl/collective_communicator.cc | 44 ++-- 20 files changed, 627 insertions(+), 544 deletions(-) diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index a2cfce1111c..c365cddae2d 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -264,7 +264,7 @@ Status BaseCollectiveExecutor::GetStatus(const Status& s) { } void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, - const CollectiveParams& col_params, + const CollectiveParams* col_params, const string& exec_key, StatusCallback done) { // See CompleteParamsAsync() how done() and the timeout callback interacts. @@ -281,7 +281,7 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, } }; auto timeout_microseconds = static_cast<int64>( - col_params.instance.impl_details.timeout_seconds * 1'000'000); + col_params->instance.impl_details.timeout_seconds * 1'000'000); if (timeout_microseconds > 0) { // TODO(xldrx): Share the timeout watchdog thread among collectives. SchedNonBlockingClosureAfter( @@ -297,15 +297,15 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, } Tensor* output = ctx->mutable_output(0); - const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE || - col_params.instance.type == GATHER_COLLECTIVE || - col_params.instance.type == PERMUTE_COLLECTIVE || - (col_params.instance.type == BROADCAST_COLLECTIVE && - col_params.is_source)) + const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE || + col_params->instance.type == GATHER_COLLECTIVE || + col_params->instance.type == PERMUTE_COLLECTIVE || + (col_params->instance.type == BROADCAST_COLLECTIVE && + col_params->is_source)) ? 
&ctx->input(0) : nullptr; CollectiveImplementationInterface* col_impl = nullptr; - Status status = CreateCollective(col_params, &col_impl); + Status status = CreateCollective(*col_params, &col_impl); if (!status.ok()) { done_safe(status); DCHECK_EQ(nullptr, col_impl); diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h index 142c825df55..8dd0a55ef18 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.h +++ b/tensorflow/core/common_runtime/base_collective_executor.h @@ -110,7 +110,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor { void StartAbort(const Status& s) override TF_LOCKS_EXCLUDED(status_mu_); - void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params, + void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams* col_params, const string& exec_key, StatusCallback done) override; void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index c42b6a61f57..9bc3b2e1b40 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -513,14 +513,14 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device, void CollectiveParamResolverLocal::InitInstanceSharedParams( const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) { - ir->shared.instance = cp->instance; - ir->shared.default_rank = -1; + ir->shared->instance = cp->instance; + ir->shared->default_rank = -1; // Set is_local and task_names in *shared prior to invoking // GetDeviceAttributesAsync. In a distributed context this function can be // called by a derived class, some of the devices may be non-local and // GetDeviceAttributesAsync will use those fields to launch RPCs. - CompleteTaskIsLocal(task_name_, &ir->shared); + CompleteTaskIsLocal(task_name_, ir->shared); } // NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks @@ -662,11 +662,11 @@ void CollectiveParamResolverLocal::CompleteInstanceLocal( if (!created_irec) { // Check that the preexisting IRec is consistent with the params passed into // this invocation. - if (ir->shared.instance.type != cp->instance.type || - ir->shared.instance.data_type != cp->instance.data_type) { + if (ir->shared->instance.type != cp->instance.type || + ir->shared->instance.data_type != cp->instance.data_type) { done(errors::Internal("Collective instance ", cp->instance.instance_key, - " expected type ", ir->shared.instance.type, - " and data_type ", ir->shared.instance.data_type, + " expected type ", ir->shared->instance.type, + " and data_type ", ir->shared->instance.data_type, " but got type ", cp->instance.type, " and data_type ", cp->instance.data_type)); return; @@ -686,7 +686,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec( status = ir->status; if (status.ok()) { // custom operator= does a deep copy. 
- cp->instance = ir->shared.instance; + cp->instance = ir->shared->instance; } } if (!status.ok()) { diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 63a3bf2e063..5a7cd54a0de 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -98,7 +98,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { struct InstanceRec { mutex mu; // Values to be shared by all instances, constant after initialization. - CollectiveParams shared; + CollectiveParams* shared; // If an error occurs during initialization this structure stays in the // table with a non-OK status. Purging the table and restarting needs to be // done at a higher level. @@ -113,7 +113,9 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { std::vector<bool> known TF_GUARDED_BY(mu); std::vector<IRConsumer> known_waiters TF_GUARDED_BY(mu); - InstanceRec() : source_rank(-1), known_count(0) {} + InstanceRec() + : shared(new CollectiveParams()), source_rank(-1), known_count(0) {} + ~InstanceRec() { shared->Unref(); } }; // Find the InstanceRec with the same instance_key as cp. If it doesn't diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc index e1ac46f2e53..611d6bbff50 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc @@ -161,11 +161,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { } TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { - CollectiveParams cps[NUM_DEVS]; + CollectiveParams* cps[NUM_DEVS]; Status statuses[NUM_DEVS]; Notification note[NUM_DEVS]; for (int i = 0; i < NUM_DEVS; ++i) { - CollectiveParams* cp = &cps[i]; + cps[i] = new CollectiveParams(); + CollectiveParams* cp = cps[i]; cp->group.group_key = 1; cp->group.group_size = 3; cp->group.device_type = DeviceType("CPU"); @@ -192,17 +193,18 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { } for (int i = 0; i < NUM_DEVS; ++i) { TF_ASSERT_OK(statuses[i]); - ASSERT_EQ(cps[i].group.device_names.size(), 3); + ASSERT_EQ(cps[i]->group.device_names.size(), 3); for (int j = 0; j < NUM_DEVS; ++j) { EXPECT_EQ( strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j), - cps[i].group.device_names[j]); - EXPECT_TRUE(cps[i].task.is_local[j]); + cps[i]->group.device_names[j]); + EXPECT_TRUE(cps[i]->task.is_local[j]); } - EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0); - EXPECT_FALSE(cps[i].is_source); - EXPECT_EQ(cps[i].default_rank, i); - EXPECT_TRUE(cps[i].group.same_num_devices_per_task); + EXPECT_EQ(cps[i]->instance.impl_details.subdiv_source_rank.size(), 0); + EXPECT_FALSE(cps[i]->is_source); + EXPECT_EQ(cps[i]->default_rank, i); + EXPECT_TRUE(cps[i]->group.same_num_devices_per_task); + cps[i]->Unref(); } } @@ -223,11 +225,12 @@ void InitializeCollectiveParamsForBroadcast(int instance_key, int device_idx, TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) { constexpr int kInstanceKey = 5; - CollectiveParams cps[NUM_DEVS]; + CollectiveParams* cps[NUM_DEVS]; Status statuses[NUM_DEVS]; Notification note[NUM_DEVS]; for (int i = 0; i < NUM_DEVS; ++i) { - CollectiveParams* cp = &cps[i]; + cps[i] = new CollectiveParams(); + 
CollectiveParams* cp = cps[i];
     InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
       string device =
@@ -245,16 +248,17 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
     TF_ASSERT_OK(statuses[i]);
-    ASSERT_EQ(cps[i].group.device_names.size(), 3);
+    ASSERT_EQ(cps[i]->group.device_names.size(), 3);
     for (int j = 0; j < NUM_DEVS; ++j) {
       EXPECT_EQ(
           strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
-          cps[i].group.device_names[j]);
-      EXPECT_TRUE(cps[i].task.is_local[j]);
+          cps[i]->group.device_names[j]);
+      EXPECT_TRUE(cps[i]->task.is_local[j]);
     }
-    EXPECT_EQ(cps[i].is_source, (i == 1));
-    EXPECT_EQ(cps[i].default_rank, i);
-    EXPECT_TRUE(cps[i].group.same_num_devices_per_task);
+    EXPECT_EQ(cps[i]->is_source, (i == 1));
+    EXPECT_EQ(cps[i]->default_rank, i);
+    EXPECT_TRUE(cps[i]->group.same_num_devices_per_task);
+    cps[i]->Unref();
   }
 }

@@ -263,11 +267,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
 // get an internal error from param resolution.
 TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
   constexpr int kInstanceKey = 8;
-  CollectiveParams cps[NUM_DEVS];
+  CollectiveParams* cps[NUM_DEVS];
   Status statuses[NUM_DEVS];
   Notification note[NUM_DEVS];
   for (int i = 0; i < NUM_DEVS; ++i) {
-    CollectiveParams* cp = &cps[i];
+    cps[i] = new CollectiveParams();
+    CollectiveParams* cp = cps[i];
     InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
       string device =
@@ -291,27 +296,28 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
             " found no source for broadcast. This could mean that there"
             " were group_size=",
             NUM_DEVS, " BcastRecvs but no BcastSend."));
+    cps[i]->Unref();
   }
 }

-CollectiveParams MakeCollectiveParams(int group_key, int instance_key,
-                                      bool is_source) {
-  CollectiveParams cp;
-  cp.group.group_key = group_key;
-  cp.group.group_size = NUM_DEVS;
-  cp.group.device_type = DeviceType("CPU");
-  cp.group.num_tasks = 1;
-  cp.instance.instance_key = instance_key;
+CollectiveParams* MakeCollectiveParams(int group_key, int instance_key,
+                                       bool is_source) {
+  auto* cp = new CollectiveParams();
+  cp->group.group_key = group_key;
+  cp->group.group_size = NUM_DEVS;
+  cp->group.device_type = DeviceType("CPU");
+  cp->group.num_tasks = 1;
+  cp->instance.instance_key = instance_key;
   // CompleteInstanceLocal only waits for the group for broadcasts.
   // Testing with broadcasts yields better coverage.
- cp.instance.type = BROADCAST_COLLECTIVE; - cp.is_source = is_source; + cp->instance.type = BROADCAST_COLLECTIVE; + cp->is_source = is_source; return cp; } TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) { CancellationManager cancel_mgr; - std::vector<CollectiveParams> cp(NUM_DEVS - 1); + std::vector<CollectiveParams*> cp(NUM_DEVS - 1); BlockingCounter start(NUM_DEVS - 1); BlockingCounter done(NUM_DEVS - 1); for (int i = 0; i < NUM_DEVS - 1; ++i) { @@ -320,11 +326,12 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], - &cancel_mgr, [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr, + [&done, cp = cp[i]](const Status& s) { EXPECT_EQ(s.code(), error::ABORTED); EXPECT_EQ(s.error_message(), "__aborted__"); done.DecrementCount(); + cp->Unref(); }); start.DecrementCount(); }); @@ -336,7 +343,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) { TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { CancellationManager cancel_mgr; - std::vector<CollectiveParams> cp(NUM_DEVS); + std::vector<CollectiveParams*> cp(NUM_DEVS); int group_key = 100; int instance_key = 100; // First do a normal CompleteParamsAsync to complete the group; @@ -349,10 +356,12 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(group_key, instance_key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], - &cancel_mgr, [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], + &cancel_mgr, + [&done, cp = cp[i]](const Status& s) { EXPECT_EQ(s.code(), error::OK); done.DecrementCount(); + cp->Unref(); }); }); } @@ -361,21 +370,21 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { BlockingCounter start(NUM_DEVS - 1); BlockingCounter done(NUM_DEVS - 1); for (int i = 0; i < NUM_DEVS - 1; ++i) { - Env::Default()->SchedClosure( - [this, group_key, instance_key, i, &cancel_mgr, &cp, &start, &done] { - string device = - strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); - cp[i] = MakeCollectiveParams(group_key, instance_key + 1, - /*is_source*/ i == 0); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], - &cancel_mgr, [&done](const Status& s) { - EXPECT_EQ(s.code(), error::ABORTED); - EXPECT_EQ(s.error_message(), - "__aborted__"); - done.DecrementCount(); - }); - start.DecrementCount(); - }); + Env::Default()->SchedClosure([this, group_key, instance_key, i, &cancel_mgr, + &cp, &start, &done] { + string device = + strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); + cp[i] = MakeCollectiveParams(group_key, instance_key + 1, + /*is_source*/ i == 0); + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr, + [&done, cp = cp[i]](const Status& s) { + EXPECT_EQ(s.code(), error::ABORTED); + EXPECT_EQ(s.error_message(), "__aborted__"); + done.DecrementCount(); + cp->Unref(); + }); + start.DecrementCount(); + }); } start.Wait(); prl_->StartAbort(Status(error::ABORTED, "__aborted__")); @@ -388,7 +397,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { int instance_key = 100; // First do a normal CompleteParamsAsync to complete the group; { - 
std::vector<CollectiveParams> cp(NUM_DEVS); + std::vector<CollectiveParams*> cp(NUM_DEVS); BlockingCounter done(NUM_DEVS); for (int i = 0; i < NUM_DEVS; ++i) { Env::Default()->SchedClosure([this, group_key, instance_key, i, @@ -397,10 +406,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(group_key, instance_key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], - &cancel_mgr, [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], + &cancel_mgr, + [&done, cp = cp[i]](const Status& s) { EXPECT_EQ(s.code(), error::OK); done.DecrementCount(); + cp->Unref(); }); }); } @@ -411,9 +422,10 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { auto complete_params = [this, &cancel_mgr](int group_key, int instance_key) { string device = "/job:localhost/replica:0/task:0/device:CPU:0"; Notification done; - auto cp = MakeCollectiveParams(group_key, instance_key, - /*is_source*/ true); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, &cancel_mgr, + auto* cp = MakeCollectiveParams(group_key, instance_key, + /*is_source*/ true); + core::ScopedUnref unref(cp); + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, &cancel_mgr, [&done](const Status& s) { EXPECT_EQ(s.code(), error::ABORTED); EXPECT_EQ(s.error_message(), "__aborted__"); @@ -449,16 +461,17 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) { while (true) { Status status; Notification n; - auto cp = + auto* cp = MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, &cancel_mgr, [&status, &n](const Status& s) { status = s; n.Notify(); }); n.WaitForNotification(); + cp->Unref(); // The status should be either OK or the aborted status. 
if (!status.ok()) { EXPECT_EQ(status.code(), error::ABORTED); diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc index e78fbef13de..ebe568d6bac 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc @@ -188,7 +188,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveContext( std::shared_ptr<CollectiveContext> col_ctx) { CHECK(col_ctx->dev_mgr); col_ctx_ = col_ctx; - col_params_ = &col_ctx->col_params; + col_params_ = col_ctx->col_params; return collective_util::InitializeDeviceAndLocality( col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, &col_ctx->device_locality); diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc index 97a1d0b46ce..378dc459da1 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc @@ -56,23 +56,24 @@ class TrivialTest : public ::testing::Test { // R = tested rank // RF = receive-from rank // ST = send_to rank vector -#define DEF_TL_TEST(D, S, R, RF, ST) \ - TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \ - CollectiveParams cp; \ - cp.group.group_size = D; \ - cp.instance.impl_details.subdiv_source_rank = {S}; \ - cp.instance.impl_details.subdiv_permutations.push_back( \ - std::vector<int>(D, 0)); \ - cp.subdiv_rank = {R}; \ - cp.is_source = (S == R); \ - EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(cp, 0)); \ - std::vector<int> expected = ST; \ - std::vector<int> send_to; \ - HierarchicalTreeBroadcaster::TreeSendTo(cp, 0, &send_to); \ - ASSERT_EQ(expected.size(), send_to.size()); \ - for (int i = 0; i < expected.size(); ++i) { \ - EXPECT_EQ(expected[i], send_to[i]); \ - } \ +#define DEF_TL_TEST(D, S, R, RF, ST) \ + TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \ + auto* cp = new CollectiveParams(); \ + core::ScopedUnref unref(cp); \ + cp->group.group_size = D; \ + cp->instance.impl_details.subdiv_source_rank = {S}; \ + cp->instance.impl_details.subdiv_permutations.push_back( \ + std::vector<int>(D, 0)); \ + cp->subdiv_rank = {R}; \ + cp->is_source = (S == R); \ + EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(*cp, 0)); \ + std::vector<int> expected = ST; \ + std::vector<int> send_to; \ + HierarchicalTreeBroadcaster::TreeSendTo(*cp, 0, &send_to); \ + ASSERT_EQ(expected.size(), send_to.size()); \ + for (int i = 0; i < expected.size(); ++i) { \ + EXPECT_EQ(expected[i], send_to[i]); \ + } \ } #define V(...) 
std::vector<int>({__VA_ARGS__}) @@ -196,12 +197,14 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { class HierarchicalTreeBroadcasterTest : public ::testing::Test { protected: - HierarchicalTreeBroadcasterTest() : device_type_(DEVICE_CPU) {} + HierarchicalTreeBroadcasterTest() + : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {} ~HierarchicalTreeBroadcasterTest() override { stop_ = true; for (auto i : instances_) delete i; if (col_exec_) col_exec_->Unref(); + if (col_params_) col_params_->Unref(); } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -262,30 +265,31 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get(), work_queue_); - col_params_.name = "test_collective"; - col_params_.instance.data_type = dtype; + col_params_ = new CollectiveParams(); + col_params_->name = "test_collective"; + col_params_->instance.data_type = dtype; static const int kGroupKey = 6; - col_params_.group.group_key = kGroupKey; + col_params_->group.group_key = kGroupKey; static const int kInstanceKey = 18; - col_params_.instance.instance_key = kInstanceKey; - col_params_.group.device_type = device_type; - col_params_.group.group_size = num_workers * num_devices_per_worker; - col_params_.instance.impl_details.subdiv_offsets.clear(); - col_params_.instance.type = BROADCAST_COLLECTIVE; + col_params_->instance.instance_key = kInstanceKey; + col_params_->group.device_type = device_type; + col_params_->group.group_size = num_workers * num_devices_per_worker; + col_params_->instance.impl_details.subdiv_offsets.clear(); + col_params_->instance.type = BROADCAST_COLLECTIVE; int num_subdivs = num_workers + (num_workers > 1 ? 1 : 0); VLOG(2) << "#subdiv=" << num_subdivs; - col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); - col_params_.subdiv_rank.resize(num_subdivs); + col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs); + col_params_->subdiv_rank.resize(num_subdivs); // Inter-machine broadcast. int subdiv_i = 0; if (num_workers > 1) { - col_params_.instance.impl_details.subdiv_permutations[subdiv_i].resize( + col_params_->instance.impl_details.subdiv_permutations[subdiv_i].resize( total_num_devices, -1); for (int i = 0, rank = 0; i < total_num_devices; i++) { if (i % num_devices_per_worker == 0) { - col_params_.instance.impl_details + col_params_->instance.impl_details .subdiv_permutations[subdiv_i][rank] = i; rank++; } @@ -293,7 +297,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { if (VLOG_IS_ON(2)) { string sp_buf; for (int p : - col_params_.instance.impl_details.subdiv_permutations[subdiv_i]) + col_params_->instance.impl_details.subdiv_permutations[subdiv_i]) strings::StrAppend(&sp_buf, p, ", "); VLOG(2) << "subdiv_i=" << subdiv_i << " perm=" << sp_buf; } @@ -301,22 +305,22 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { } // Intra-machine broadcast. for (int i = 0; subdiv_i < num_subdivs; i++, subdiv_i++) { - col_params_.instance.impl_details.subdiv_permutations[subdiv_i].resize( + col_params_->instance.impl_details.subdiv_permutations[subdiv_i].resize( total_num_devices, -1); int perm_i_base = i * num_devices_per_worker; VLOG(2) << "subdiv_i=" << subdiv_i << " i=" << i << " perm_i_base=" << perm_i_base << " subdiv_perms.size=" - << col_params_.instance.impl_details.subdiv_permutations.size(); + << col_params_->instance.impl_details.subdiv_permutations.size(); // subdiv for worker i. 
for (int j = perm_i_base, rank = 0; j < perm_i_base + num_devices_per_worker; j++, rank++) { - col_params_.instance.impl_details.subdiv_permutations[subdiv_i][rank] = + col_params_->instance.impl_details.subdiv_permutations[subdiv_i][rank] = j; } if (VLOG_IS_ON(2)) { string sp_buf; for (int p : - col_params_.instance.impl_details.subdiv_permutations[subdiv_i]) + col_params_->instance.impl_details.subdiv_permutations[subdiv_i]) strings::StrAppend(&sp_buf, p, ", "); VLOG(2) << "subdiv_i=" << subdiv_i << " perm=" << sp_buf; } @@ -333,16 +337,16 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { dev_name = strings::StrCat(task_name, "/device:CPU:", di); } VLOG(2) << "dev=" << dev_name; - col_params_.group.device_names.push_back(dev_name); - col_params_.group.task_names.push_back(task_name); - col_params_.task.is_local.push_back(true); + col_params_->group.device_names.push_back(dev_name); + col_params_->group.task_names.push_back(task_name); + col_params_->task.is_local.push_back(true); } } for (int wi = 0; wi < num_workers; wi++) { for (int di = 0; di < num_devices_per_worker; di++) { int default_rank = wi * num_devices_per_worker + di; instances_.push_back(new DeviceInstance( - default_rank, col_params_.group.device_names[default_rank], + default_rank, col_params_->group.device_names[default_rank], device_type, this)); } } @@ -435,7 +439,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { // Copy the expected value from the broadcast source tensor std::vector<T> expected(tensor_len, 0.0); - const CollectiveParams& cp = instances_[0]->col_params_; + const CollectiveParams& cp = *instances_[0]->col_params_; int broadcast_dev_id = cp.instance.impl_details.subdiv_permutations [0][cp.instance.impl_details.subdiv_source_rank[0]]; @@ -558,27 +562,29 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { : parent_(parent), dev_name_(dev_name), device_type_(device_type), - rank_(rank) { + rank_(rank), + col_params_(new CollectiveParams()) { TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_)); - col_params_.name = parent_->col_params_.name; - col_params_.instance.data_type = parent_->col_params_.instance.data_type; - col_params_.group = parent_->col_params_.group; - col_params_.instance.instance_key = - parent_->col_params_.instance.instance_key; - col_params_.task.is_local = parent_->col_params_.task.is_local; - col_params_.instance.impl_details.subdiv_permutations = - parent_->col_params_.instance.impl_details.subdiv_permutations; - col_params_.subdiv_rank = parent_->col_params_.subdiv_rank; + col_params_->name = parent_->col_params_->name; + col_params_->instance.data_type = + parent_->col_params_->instance.data_type; + col_params_->group = parent_->col_params_->group; + col_params_->instance.instance_key = + parent_->col_params_->instance.instance_key; + col_params_->task.is_local = parent_->col_params_->task.is_local; + col_params_->instance.impl_details.subdiv_permutations = + parent_->col_params_->instance.impl_details.subdiv_permutations; + col_params_->subdiv_rank = parent_->col_params_->subdiv_rank; - int group_size = col_params_.group.group_size; - CHECK_EQ(group_size, col_params_.group.device_names.size()); + int group_size = col_params_->group.group_size; + CHECK_EQ(group_size, col_params_->group.device_names.size()); // Default rank is order in device_names. 
- col_params_.default_rank = rank; + col_params_->default_rank = rank; - auto& impl = col_params_.instance.impl_details; + auto& impl = col_params_->instance.impl_details; size_t num_subdivs = impl.subdiv_permutations.size(); impl.subdiv_source_rank.resize(num_subdivs, 0); - col_params_.subdiv_rank.resize(num_subdivs); + col_params_->subdiv_rank.resize(num_subdivs); for (size_t si = 0; si < num_subdivs; si++) { int perm_rank = -1; for (int i = 0; i < group_size; i++) { @@ -587,18 +593,20 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { break; } } - col_params_.subdiv_rank[si] = perm_rank; + col_params_->subdiv_rank[si] = perm_rank; } string rank_buf; - for (int r : col_params_.subdiv_rank) { + for (int r : col_params_->subdiv_rank) { strings::StrAppend(&rank_buf, r, ", "); } VLOG(1) << "default=" << rank << " subdiv_ranks=" << rank_buf; - col_params_.is_source = - col_params_.subdiv_rank[0] == impl.subdiv_source_rank[0]; + col_params_->is_source = + col_params_->subdiv_rank[0] == impl.subdiv_source_rank[0]; } + ~DeviceInstance() { col_params_->Unref(); } + void InitTensor(DataType dtype, const TensorShape& shape, const InitFunc& f) { tensor_ = @@ -641,22 +649,22 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { op_params.op_device_context = dev_ctx; int forward_from[] = {OpKernelContext::Params::kNeverForward}; if (forward_input) forward_from[0] = 0; - if (col_params_.is_source) { + if (col_params_->is_source) { op_params.forward_from_array = &forward_from[0]; } AllocatorAttributes generic_alloc_attr; op_params.output_attr_array = &generic_alloc_attr; std::unique_ptr<OpKernel> op = - col_params_.is_source - ? parent_->GetCollectiveBcastSend(col_params_, &tensor_, + col_params_->is_source + ? parent_->GetCollectiveBcastSend(*col_params_, &tensor_, DEVICE_CPU, device_) - : parent_->GetCollectiveBcastRecv(col_params_, tensor_.shape(), + : parent_->GetCollectiveBcastRecv(*col_params_, tensor_.shape(), DEVICE_CPU, device_); op_params.op_kernel = op.get(); OpKernelContext ctx(&op_params, 1); Tensor* output_tensor_ptr = nullptr; - if (col_params_.is_source) { + if (col_params_->is_source) { TF_CHECK_OK(ctx.forward_input_or_allocate_output( {0}, 0, tensor_.shape(), &output_tensor_ptr)); } else { @@ -665,11 +673,11 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { } CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0)); const Tensor* input_tensor_ptr = - col_params_.is_source ? &tensor_ : nullptr; + col_params_->is_source ? &tensor_ : nullptr; // Prepare a Broadcaster instance. 
string exec_key = - strings::StrCat(col_params_.instance.instance_key, ":0:0"); + strings::StrCat(col_params_->instance.instance_key, ":0:0"); HierarchicalTreeBroadcaster* broadcaster = new HierarchicalTreeBroadcaster; core::ScopedUnref unref(broadcaster); @@ -694,7 +702,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { int rank_; Tensor tensor_; Device* device_; - CollectiveParams col_params_; + CollectiveParams* col_params_; Status status_; }; // class DeviceInstance @@ -708,7 +716,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { std::unique_ptr<DeviceResolverLocal> dev_resolver_; std::shared_ptr<UnboundedWorkQueue> work_queue_; std::vector<DeviceInstance*> instances_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_; std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_; std::unique_ptr<string> gpu_ring_order_; @@ -720,33 +728,35 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { }; TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams1Task8GPU) { - CollectiveParams cp; - PrepColParamsForSubdivPermsTest(&cp, 1, 8); + auto* cp = new CollectiveParams(); + core::ScopedUnref unref(cp); + PrepColParamsForSubdivPermsTest(cp, 1, 8); // source 0 device 0 - cp.source_rank = 0; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0}); + cp->source_rank = 0; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0}); // source 2 device 2 - cp.source_rank = 2; - cp.default_rank = 2; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2}); + cp->source_rank = 2; + cp->default_rank = 2; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2}); // source 2 device 0 - cp.source_rank = 2; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2}); + cp->source_rank = 2; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2}); } TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) { - CollectiveParams cp; - PrepColParamsForSubdivPermsTest(&cp, 4, 8); + auto* cp = new CollectiveParams(); + core::ScopedUnref unref(cp); + PrepColParamsForSubdivPermsTest(cp, 4, 8); // source 0 device 0 - cp.source_rank = 0; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, + cp->source_rank = 0; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{0, 8, 16, 24}, {0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10, 11, 12, 13, 14, 15}, @@ -755,9 +765,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) { {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0}); // source 2 device 0 - cp.source_rank = 2; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, + cp->source_rank = 2; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{2, 8, 16, 24}, {0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10, 11, 12, 13, 14, 15}, @@ -766,9 +776,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) { {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0}); // source 9 device 9 - cp.source_rank = 9; - cp.default_rank = 9; - RunSubdivPermsTest(&cp, + cp->source_rank = 9; + cp->default_rank = 9; + RunSubdivPermsTest(cp, {{0, 9, 16, 24}, {0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10, 11, 12, 13, 14, 15}, @@ -778,28 +788,29 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) { } TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) { - CollectiveParams cp; + auto* cp = new CollectiveParams(); + core::ScopedUnref unref(cp); int num_tasks = 4; - cp.group.device_type = 
DeviceType("GPU"); - cp.group.num_tasks = num_tasks; - cp.group.group_size = 0; - cp.instance.type = BROADCAST_COLLECTIVE; - cp.instance.impl_details.collective_name = "HierarchicalTreeBroadcast"; + cp->group.device_type = DeviceType("GPU"); + cp->group.num_tasks = num_tasks; + cp->group.group_size = 0; + cp->instance.type = BROADCAST_COLLECTIVE; + cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast"; std::vector<int> dev_per_task = {4, 4, 6, 8}; - for (int ti = 0; ti < cp.group.num_tasks; ti++) { + for (int ti = 0; ti < cp->group.num_tasks; ti++) { string task_name = strings::StrCat("/job:worker/replica:0/task:", ti); for (int di = 0; di < dev_per_task[ti]; di++) { string dev_name = strings::StrCat(task_name, "/device:GPU:", di); - cp.group.task_names.push_back(task_name); - cp.group.device_names.push_back(dev_name); - cp.group.group_size++; + cp->group.task_names.push_back(task_name); + cp->group.device_names.push_back(dev_name); + cp->group.group_size++; } } // source 0 device 0 - cp.source_rank = 0; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, + cp->source_rank = 0; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{0, 4, 8, 14}, {0, 1, 2, 3}, {4, 5, 6, 7}, @@ -808,9 +819,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) { {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0}); // source 2 device 0 - cp.source_rank = 2; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, + cp->source_rank = 2; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{2, 4, 8, 14}, {0, 1, 2, 3}, {4, 5, 6, 7}, @@ -819,9 +830,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) { {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0}); // source 9 device 5 - cp.source_rank = 9; - cp.default_rank = 5; - RunSubdivPermsTest(&cp, + cp->source_rank = 9; + cp->default_rank = 5; + RunSubdivPermsTest(cp, {{0, 4, 9, 14}, {0, 1, 2, 3}, {4, 5, 6, 7}, diff --git a/tensorflow/core/common_runtime/permuter.cc b/tensorflow/core/common_runtime/permuter.cc index 9aee5e5d5c9..c1dcd20dc06 100644 --- a/tensorflow/core/common_runtime/permuter.cc +++ b/tensorflow/core/common_runtime/permuter.cc @@ -54,7 +54,7 @@ Status Permuter::InitializeCollectiveContext( std::shared_ptr<CollectiveContext> col_ctx) { DCHECK(col_ctx->dev_mgr); col_ctx_ = col_ctx; - col_params_ = &col_ctx->col_params; + col_params_ = col_ctx->col_params; return collective_util::InitializeDeviceAndLocality( col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, &col_ctx->device_locality); diff --git a/tensorflow/core/common_runtime/permuter_test.cc b/tensorflow/core/common_runtime/permuter_test.cc index 10c527ca573..a5f8add6c30 100644 --- a/tensorflow/core/common_runtime/permuter_test.cc +++ b/tensorflow/core/common_runtime/permuter_test.cc @@ -107,12 +107,14 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { class PermuterTest : public ::testing::Test { protected: - PermuterTest() : device_type_(DEVICE_CPU) {} + PermuterTest() + : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {} ~PermuterTest() override { stop_ = true; for (auto i : instances_) delete i; if (col_exec_) col_exec_->Unref(); + if (col_params_) col_params_->Unref(); } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -170,12 +172,13 @@ class PermuterTest : public ::testing::Test { col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get(), work_queue_); - col_params_.name = "test_collective"; - col_params_.instance.data_type = dtype; + col_params_ = new CollectiveParams(); + col_params_->name 
= "test_collective"; + col_params_->instance.data_type = dtype; static const int kInstanceKey = 18; - col_params_.instance.instance_key = kInstanceKey; - col_params_.group.device_type = device_type; - col_params_.instance.type = PERMUTE_COLLECTIVE; + col_params_->instance.instance_key = kInstanceKey; + col_params_->group.device_type = device_type; + col_params_->instance.type = PERMUTE_COLLECTIVE; // Set up all the fake device contexts. for (int wi = 0; wi < num_workers; wi++) { @@ -187,12 +190,12 @@ class PermuterTest : public ::testing::Test { } else { dev_name = strings::StrCat(task_name, "/device:CPU:", di); } - col_params_.group.device_names.push_back(dev_name); - col_params_.instance.devices.push_back(dev_name); + col_params_->group.device_names.push_back(dev_name); + col_params_->instance.devices.push_back(dev_name); int default_rank = wi * num_devices_per_worker + di; permutation_.push_back(default_rank); - col_params_.group.task_names.push_back(task_name); - col_params_.task.is_local.push_back(true); + col_params_->group.task_names.push_back(task_name); + col_params_->task.is_local.push_back(true); } } @@ -210,13 +213,13 @@ class PermuterTest : public ::testing::Test { std::next_permutation(permutation_.begin() + i, permutation_.begin() + i + 2); } - col_params_.instance.permutation = permutation_; + col_params_->instance.permutation = permutation_; for (int wi = 0; wi < num_workers; wi++) { for (int di = 0; di < num_devices_per_worker; di++) { int default_rank = wi * num_devices_per_worker + di; instances_.push_back(new DeviceInstance( - default_rank, col_params_.group.device_names[default_rank], + default_rank, col_params_->group.device_names[default_rank], device_type, this)); } } @@ -320,25 +323,30 @@ class PermuterTest : public ::testing::Test { : parent_(parent), dev_name_(dev_name), device_type_(device_type), - rank_(rank) { + rank_(rank), + col_params_(new CollectiveParams()) { TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_)); - col_params_.name = parent_->col_params_.name; - col_params_.instance.data_type = parent_->col_params_.instance.data_type; - col_params_.instance.instance_key = - parent_->col_params_.instance.instance_key; - col_params_.group.device_type = parent_->col_params_.group.device_type; - col_params_.group.device_names = parent_->col_params_.group.device_names; - col_params_.instance.devices = parent_->col_params_.instance.devices; - col_params_.instance.permutation = - parent->col_params_.instance.permutation; - col_params_.group.task_names = parent_->col_params_.group.task_names; - col_params_.task.is_local = parent_->col_params_.task.is_local; - CHECK_EQ(col_params_.instance.devices.size(), - col_params_.group.device_names.size()); + col_params_->name = parent_->col_params_->name; + col_params_->instance.data_type = + parent_->col_params_->instance.data_type; + col_params_->instance.instance_key = + parent_->col_params_->instance.instance_key; + col_params_->group.device_type = parent_->col_params_->group.device_type; + col_params_->group.device_names = + parent_->col_params_->group.device_names; + col_params_->instance.devices = parent_->col_params_->instance.devices; + col_params_->instance.permutation = + parent->col_params_->instance.permutation; + col_params_->group.task_names = parent_->col_params_->group.task_names; + col_params_->task.is_local = parent_->col_params_->task.is_local; + CHECK_EQ(col_params_->instance.devices.size(), + col_params_->group.device_names.size()); // Default rank is order in device_names. 
- col_params_.default_rank = rank; + col_params_->default_rank = rank; } + ~DeviceInstance() { col_params_->Unref(); } + void InitTensor(DataType dtype, const TensorShape& shape, const InitFunc& f) { tensor_input_ = @@ -387,7 +395,7 @@ class PermuterTest : public ::testing::Test { // Prepare a Permuter instance. string exec_key = - strings::StrCat(col_params_.instance.instance_key, ":0:0"); + strings::StrCat(col_params_->instance.instance_key, ":0:0"); Permuter* permuter = new Permuter; core::ScopedUnref unref(permuter); auto col_ctx = std::make_shared<CollectiveContext>( @@ -412,7 +420,7 @@ class PermuterTest : public ::testing::Test { Tensor tensor_input_; Tensor tensor_output_; Device* device_; - CollectiveParams col_params_; + CollectiveParams* col_params_; Status status_; }; // class DeviceInstance @@ -425,7 +433,7 @@ class PermuterTest : public ::testing::Test { std::unique_ptr<DeviceResolverLocal> dev_resolver_; std::shared_ptr<UnboundedWorkQueue> work_queue_; std::vector<DeviceInstance*> instances_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_; std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_; std::unique_ptr<string> gpu_ring_order_; diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc index e664eb90865..a081d2cb730 100644 --- a/tensorflow/core/common_runtime/ring_alg.cc +++ b/tensorflow/core/common_runtime/ring_alg.cc @@ -245,7 +245,7 @@ Status RingAlg::InitializeCollectiveContext( std::shared_ptr<CollectiveContext> col_ctx) { DCHECK(col_ctx->dev_mgr); col_ctx_ = col_ctx; - col_params_ = &col_ctx->col_params; + col_params_ = col_ctx->col_params; return collective_util::InitializeDeviceAndLocality( col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, &col_ctx->device_locality); diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc index 1f23ee1a8a7..0a6f81a5a2a 100644 --- a/tensorflow/core/common_runtime/ring_gatherer_test.cc +++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc @@ -115,7 +115,8 @@ static int64 kStepId = 123; class RingGathererTest : public ::testing::Test { protected: - RingGathererTest() : device_type_(DEVICE_CPU) {} + RingGathererTest() + : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {} #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM void InitGPUDevices() { @@ -132,6 +133,7 @@ class RingGathererTest : public ::testing::Test { stop_ = true; for (auto i : instances_) delete i; if (col_exec_) col_exec_->Unref(); + if (col_params_) col_params_->Unref(); } void Init(int num_workers, int num_devices, DataType dtype, @@ -180,24 +182,25 @@ class RingGathererTest : public ::testing::Test { col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get(), work_queue_); - col_params_.name = "test_collective"; + col_params_ = new CollectiveParams(); + col_params_->name = "test_collective"; static const int kGroupKey = 5; - col_params_.group.group_key = kGroupKey; - col_params_.group.device_type = device_type; - col_params_.group.group_size = num_workers * num_devices; + col_params_->group.group_key = kGroupKey; + col_params_->group.device_type = device_type; + col_params_->group.group_size = num_workers * num_devices; static const int kInstanceKey = 17; - col_params_.instance.instance_key = kInstanceKey; - col_params_.instance.impl_details.subdiv_offsets.clear(); - col_params_.instance.type = GATHER_COLLECTIVE; - 
col_params_.instance.impl_details.collective_name = "RingGather"; - col_params_.instance.data_type = dtype; - col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); - col_params_.subdiv_rank.resize(num_subdivs); + col_params_->instance.instance_key = kInstanceKey; + col_params_->instance.impl_details.subdiv_offsets.clear(); + col_params_->instance.type = GATHER_COLLECTIVE; + col_params_->instance.impl_details.collective_name = "RingGather"; + col_params_->instance.data_type = dtype; + col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs); + col_params_->subdiv_rank.resize(num_subdivs); int subdiv_stride = num_devices / num_subdivs; for (int sdi = 0; sdi < num_subdivs; ++sdi) { - col_params_.instance.impl_details.subdiv_offsets.push_back(sdi * - subdiv_stride); - col_params_.subdiv_rank[sdi] = sdi * subdiv_stride; + col_params_->instance.impl_details.subdiv_offsets.push_back( + sdi * subdiv_stride); + col_params_->subdiv_rank[sdi] = sdi * subdiv_stride; } // Set up a local device ring order that's not just 0,1,2... @@ -225,16 +228,16 @@ class RingGathererTest : public ::testing::Test { dev_name = strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size()); } - col_params_.group.device_names.push_back(dev_name); - col_params_.group.task_names.push_back(task_name); + col_params_->group.device_names.push_back(dev_name); + col_params_->group.task_names.push_back(task_name); // Normally each device would set is_local to its own perspective but // this test runs in a single process so is_local is always true. - col_params_.task.is_local.push_back(true); + col_params_->task.is_local.push_back(true); for (int sdi = 0; sdi < num_subdivs; ++sdi) { int rotated_di = - (di + col_params_.instance.impl_details.subdiv_offsets[sdi]) % + (di + col_params_->instance.impl_details.subdiv_offsets[sdi]) % num_devices; - col_params_.instance.impl_details.subdiv_permutations[sdi].push_back( + col_params_->instance.impl_details.subdiv_permutations[sdi].push_back( wi * num_devices + local_ring_order[rotated_di]); } } @@ -243,7 +246,7 @@ class RingGathererTest : public ::testing::Test { for (int di = 0; di < num_devices; ++di) { int rank = wi * num_devices + di; instances_.push_back(new DeviceInstance( - rank, col_params_.group.device_names[rank], device_type_, this)); + rank, col_params_->group.device_names[rank], device_type_, this)); } } } @@ -387,39 +390,42 @@ class RingGathererTest : public ::testing::Test { : parent_(parent), dev_name_(dev_name), device_type_(device_type), - rank_(rank) { + rank_(rank), + col_params_(new CollectiveParams()) { TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_)) << "Couldn't find device " << dev_name << " existing devices: " << parent_->dev_mgr_->DebugString(); - col_params_.name = parent_->col_params_.name; - col_params_.group = parent_->col_params_.group; - col_params_.instance = parent->col_params_.instance; - col_params_.task.is_local = parent_->col_params_.task.is_local; - col_params_.subdiv_rank = parent_->col_params_.subdiv_rank; + col_params_->name = parent_->col_params_->name; + col_params_->group = parent_->col_params_->group; + col_params_->instance = parent->col_params_->instance; + col_params_->task.is_local = parent_->col_params_->task.is_local; + col_params_->subdiv_rank = parent_->col_params_->subdiv_rank; - int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size()); - int group_size = col_params_.group.group_size; + int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size()); + int group_size 
= col_params_->group.group_size; CHECK_EQ(group_size, - static_cast<int>(col_params_.group.device_names.size())); + static_cast<int>(col_params_->group.device_names.size())); // Id of this device is at rank position in first subdiv perm. int my_device_id = - col_params_.instance.impl_details.subdiv_permutations[0][rank]; - col_params_.default_rank = my_device_id; + col_params_->instance.impl_details.subdiv_permutations[0][rank]; + col_params_->default_rank = my_device_id; // Set rank for all other subdivs by finding that device_id. for (int sdi = 0; sdi < num_subdivs; ++sdi) { - for (int r = 0; r < static_cast<int>(col_params_.instance.impl_details + for (int r = 0; r < static_cast<int>(col_params_->instance.impl_details .subdiv_permutations[sdi] .size()); ++r) { if (my_device_id == - col_params_.instance.impl_details.subdiv_permutations[sdi][r]) { - col_params_.subdiv_rank[sdi] = r; + col_params_->instance.impl_details.subdiv_permutations[sdi][r]) { + col_params_->subdiv_rank[sdi] = r; break; } } } } + ~DeviceInstance() { col_params_->Unref(); } + void InitTensor(DataType dtype, const TensorShape& shape, const std::function<void(Tensor*)>& init_f) { input_tensor_ = @@ -464,7 +470,7 @@ class RingGathererTest : public ::testing::Test { AllocatorAttributes generic_alloc_attr; op_params.output_attr_array = &generic_alloc_attr; std::unique_ptr<OpKernel> op = parent_->GetCollectiveGather( - col_params_, &input_tensor_, DEVICE_CPU, device_); + *col_params_, &input_tensor_, DEVICE_CPU, device_); op_params.op_kernel = op.get(); OpKernelContext ctx(&op_params, 1); @@ -478,7 +484,7 @@ class RingGathererTest : public ::testing::Test { CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0)); // Prepare a RingGatherer instance. string exec_key = - strings::StrCat(col_params_.instance.instance_key, ":0:0"); + strings::StrCat(col_params_->instance.instance_key, ":0:0"); RingGatherer* gatherer = new RingGatherer; core::ScopedUnref unref(gatherer); auto col_ctx = std::make_shared<CollectiveContext>( @@ -507,7 +513,7 @@ class RingGathererTest : public ::testing::Test { Tensor input_tensor_; Tensor output_tensor_; Device* device_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::unique_ptr<CollectiveAdapter> ca_; std::unique_ptr<OpKernelContext> ctx_; Status status_; @@ -521,7 +527,7 @@ class RingGathererTest : public ::testing::Test { std::unique_ptr<DeviceResolverLocal> dev_resolver_; std::shared_ptr<UnboundedWorkQueue> work_queue_; std::vector<DeviceInstance*> instances_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_; std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_; std::unique_ptr<string> gpu_ring_order_; @@ -530,28 +536,28 @@ class RingGathererTest : public ::testing::Test { CancellationManager cancellation_manager_; }; -CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, - const int num_tasks) { - CollectiveParams cp; +CollectiveParams* SetUpCollectiveParams(const int num_devs_per_task, + const int num_tasks) { + auto* cp = new CollectiveParams(); const int kNumDevs = num_devs_per_task * num_tasks; - cp.group.group_key = 1; - cp.group.group_size = kNumDevs; - cp.group.device_type = DeviceType("GPU"); - cp.group.num_tasks = num_tasks; - cp.instance.instance_key = 3; - cp.instance.type = GATHER_COLLECTIVE; - cp.instance.data_type = DataType(DT_FLOAT); - cp.instance.shape = TensorShape({kNumDevs * kNumDevs}); - cp.instance.impl_details.collective_name = "RingGather"; - 
cp.instance.impl_details.subdiv_offsets.push_back(0); - cp.is_source = false; + cp->group.group_key = 1; + cp->group.group_size = kNumDevs; + cp->group.device_type = DeviceType("GPU"); + cp->group.num_tasks = num_tasks; + cp->instance.instance_key = 3; + cp->instance.type = GATHER_COLLECTIVE; + cp->instance.data_type = DataType(DT_FLOAT); + cp->instance.shape = TensorShape({kNumDevs * kNumDevs}); + cp->instance.impl_details.collective_name = "RingGather"; + cp->instance.impl_details.subdiv_offsets.push_back(0); + cp->is_source = false; for (int i = 0; i < kNumDevs; ++i) { int task_id = i / num_devs_per_task; int dev_id = i % num_devs_per_task; string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id); string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id); - cp.group.task_names.push_back(task_name); - cp.group.device_names.push_back(device_name); + cp->group.task_names.push_back(task_name); + cp->group.device_names.push_back(device_name); } return cp; } @@ -559,23 +565,24 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, TEST_F(RingGathererTest, InitializeParams) { const int kNumDevsPerTask = 8; const int kNumTasks = 3; - CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + core::ScopedUnref unref(cp); - cp.default_rank = 0; - cp.instance.impl_details.subdiv_offsets = {}; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, + cp->default_rank = 0; + cp->instance.impl_details.subdiv_offsets = {}; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, {0}); - cp.instance.impl_details.subdiv_offsets = {0}; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, + cp->instance.impl_details.subdiv_offsets = {0}; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, {0}); - cp.default_rank = 3; - cp.instance.impl_details.subdiv_offsets = {}; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, + cp->default_rank = 3; + cp->instance.impl_details.subdiv_offsets = {}; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, {3}); } diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc index 3b153e4ca1d..89f8605ae25 100644 --- a/tensorflow/core/common_runtime/ring_reducer_test.cc +++ b/tensorflow/core/common_runtime/ring_reducer_test.cc @@ -138,7 +138,8 @@ static int64 kStepId = 123; class RingReducerTest : public ::testing::Test { protected: - RingReducerTest() : device_type_(DEVICE_CPU) {} + RingReducerTest() + : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {} #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM void InitGPUDevices() { @@ -155,6 +156,7 @@ class RingReducerTest : public ::testing::Test { stop_ = true; for (auto i : instances_) delete i; if (col_exec_) col_exec_->Unref(); + if (col_params_) col_params_->Unref(); } void Init(int num_workers, int num_devices, DataType dtype, @@ -203,24 +205,25 @@ class RingReducerTest : public ::testing::Test { col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get(), work_queue_); - col_params_.name = 
"test_collective"; + col_params_ = new CollectiveParams(); + col_params_->name = "test_collective"; static const int kGroupKey = 5; - col_params_.group.group_key = kGroupKey; - col_params_.group.device_type = device_type; - col_params_.group.group_size = num_workers * num_devices; + col_params_->group.group_key = kGroupKey; + col_params_->group.device_type = device_type; + col_params_->group.group_size = num_workers * num_devices; static const int kInstanceKey = 17; - col_params_.instance.instance_key = kInstanceKey; - col_params_.instance.impl_details.subdiv_offsets.clear(); - col_params_.instance.type = REDUCTION_COLLECTIVE; - col_params_.instance.impl_details.collective_name = "RingReduce"; - col_params_.instance.data_type = dtype; - col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); - col_params_.subdiv_rank.resize(num_subdivs); + col_params_->instance.instance_key = kInstanceKey; + col_params_->instance.impl_details.subdiv_offsets.clear(); + col_params_->instance.type = REDUCTION_COLLECTIVE; + col_params_->instance.impl_details.collective_name = "RingReduce"; + col_params_->instance.data_type = dtype; + col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs); + col_params_->subdiv_rank.resize(num_subdivs); int subdiv_stride = num_devices / num_subdivs; for (int sdi = 0; sdi < num_subdivs; ++sdi) { - col_params_.instance.impl_details.subdiv_offsets.push_back(sdi * - subdiv_stride); - col_params_.subdiv_rank[sdi] = sdi * subdiv_stride; + col_params_->instance.impl_details.subdiv_offsets.push_back( + sdi * subdiv_stride); + col_params_->subdiv_rank[sdi] = sdi * subdiv_stride; } // Set up a local device ring order that's not just 0,1,2... @@ -242,23 +245,23 @@ class RingReducerTest : public ::testing::Test { // Set up all of the fake device contexts. for (int wi = 0; wi < num_workers; ++wi) { string task_name = strings::StrCat("/job:worker/replica:0/task:", wi); - col_params_.group.num_devices_per_task[task_name] = num_devices; + col_params_->group.num_devices_per_task[task_name] = num_devices; for (int di = 0; di < num_devices; ++di) { string dev_name = strings::StrCat(task_name, "/cpu:", di); if (device_type == DEVICE_GPU) { dev_name = strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size()); } - col_params_.group.device_names.push_back(dev_name); - col_params_.group.task_names.push_back(task_name); + col_params_->group.device_names.push_back(dev_name); + col_params_->group.task_names.push_back(task_name); // Normally each device would set is_local to its own perspective but // this test runs in a single process so is_local is always true. 
- col_params_.task.is_local.push_back(true); + col_params_->task.is_local.push_back(true); for (int sdi = 0; sdi < num_subdivs; ++sdi) { int rotated_di = - (di + col_params_.instance.impl_details.subdiv_offsets[sdi]) % + (di + col_params_->instance.impl_details.subdiv_offsets[sdi]) % num_devices; - col_params_.instance.impl_details.subdiv_permutations[sdi].push_back( + col_params_->instance.impl_details.subdiv_permutations[sdi].push_back( wi * num_devices + local_ring_order[rotated_di]); } } @@ -267,7 +270,7 @@ class RingReducerTest : public ::testing::Test { for (int di = 0; di < num_devices; ++di) { int rank = wi * num_devices + di; instances_.push_back(new DeviceInstance( - rank, col_params_.group.device_names[rank], device_type_, this)); + rank, col_params_->group.device_names[rank], device_type_, this)); } } } @@ -413,39 +416,42 @@ class RingReducerTest : public ::testing::Test { : parent_(parent), dev_name_(dev_name), device_type_(device_type), - rank_(rank) { + rank_(rank), + col_params_(new CollectiveParams()) { TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_)) << "Couldn't find device " << dev_name << " existing devices: " << parent_->dev_mgr_->DebugString(); - col_params_.name = parent_->col_params_.name; - col_params_.group = parent_->col_params_.group; - col_params_.instance = parent->col_params_.instance; - col_params_.task.is_local = parent_->col_params_.task.is_local; - col_params_.subdiv_rank = parent_->col_params_.subdiv_rank; + col_params_->name = parent_->col_params_->name; + col_params_->group = parent_->col_params_->group; + col_params_->instance = parent->col_params_->instance; + col_params_->task.is_local = parent_->col_params_->task.is_local; + col_params_->subdiv_rank = parent_->col_params_->subdiv_rank; - int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size()); - int group_size = col_params_.group.group_size; + int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size()); + int group_size = col_params_->group.group_size; CHECK_EQ(group_size, - static_cast<int>(col_params_.group.device_names.size())); + static_cast<int>(col_params_->group.device_names.size())); // Id of this device is at rank position in first subdiv perm. int my_device_id = - col_params_.instance.impl_details.subdiv_permutations[0][rank]; - col_params_.default_rank = my_device_id; + col_params_->instance.impl_details.subdiv_permutations[0][rank]; + col_params_->default_rank = my_device_id; // Set rank for all other subdivs by finding that device_id. 
for (int sdi = 0; sdi < num_subdivs; ++sdi) { - for (int r = 0; r < static_cast<int>(col_params_.instance.impl_details + for (int r = 0; r < static_cast<int>(col_params_->instance.impl_details .subdiv_permutations[sdi] .size()); ++r) { if (my_device_id == - col_params_.instance.impl_details.subdiv_permutations[sdi][r]) { - col_params_.subdiv_rank[sdi] = r; + col_params_->instance.impl_details.subdiv_permutations[sdi][r]) { + col_params_->subdiv_rank[sdi] = r; break; } } } } + ~DeviceInstance() { col_params_->Unref(); } + void InitTensor(DataType dtype, const TensorShape& shape, const std::function<void(Tensor*)>& init_f) { tensor_ = @@ -466,10 +472,12 @@ class RingReducerTest : public ::testing::Test { } void DoReduce() { - merge_op_ = GetAdd(col_params_.instance.data_type, device_type_, device_); - final_op_ = GetDiv(col_params_.instance.data_type, device_type_, device_); - col_params_.merge_op = merge_op_.get(); - col_params_.final_op = final_op_.get(); + merge_op_ = + GetAdd(col_params_->instance.data_type, device_type_, device_); + final_op_ = + GetDiv(col_params_->instance.data_type, device_type_, device_); + col_params_->merge_op = merge_op_.get(); + col_params_->final_op = final_op_.get(); // Prepare an OpKernelContext. OpKernelContext::Params op_params; @@ -496,7 +504,7 @@ class RingReducerTest : public ::testing::Test { AllocatorAttributes generic_alloc_attr; op_params.output_attr_array = &generic_alloc_attr; std::unique_ptr<OpKernel> op = parent_->GetCollectiveReduce( - col_params_, &tensor_, DEVICE_CPU, device_); + *col_params_, &tensor_, DEVICE_CPU, device_); op_params.op_kernel = op.get(); OpKernelContext ctx(&op_params, 1); @@ -509,7 +517,7 @@ class RingReducerTest : public ::testing::Test { // Prepare a RingReducer instance. string exec_key = - strings::StrCat(col_params_.instance.instance_key, ":0:0"); + strings::StrCat(col_params_->instance.instance_key, ":0:0"); RingReducer* reducer = new RingReducer; core::ScopedUnref unref(reducer); auto col_ctx = std::make_shared<CollectiveContext>( @@ -535,7 +543,7 @@ class RingReducerTest : public ::testing::Test { int rank_; Tensor tensor_; Device* device_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::unique_ptr<OpKernel> merge_op_; std::unique_ptr<OpKernel> final_op_; std::unique_ptr<CollectiveAdapter> ca_; @@ -551,7 +559,7 @@ class RingReducerTest : public ::testing::Test { std::unique_ptr<DeviceResolverLocal> dev_resolver_; std::shared_ptr<UnboundedWorkQueue> work_queue_; std::vector<DeviceInstance*> instances_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_; std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_; std::unique_ptr<string> gpu_ring_order_; @@ -560,28 +568,28 @@ class RingReducerTest : public ::testing::Test { CancellationManager cancellation_manager_; }; -CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, - const int num_tasks) { - CollectiveParams cp; +CollectiveParams* SetUpCollectiveParams(const int num_devs_per_task, + const int num_tasks) { + auto cp = new CollectiveParams(); const int kNumDevs = num_devs_per_task * num_tasks; - cp.group.group_key = 1; - cp.group.group_size = kNumDevs; - cp.group.device_type = DeviceType("GPU"); - cp.group.num_tasks = num_tasks; - cp.instance.instance_key = 3; - cp.instance.type = REDUCTION_COLLECTIVE; - cp.instance.data_type = DataType(DT_FLOAT); - cp.instance.shape = TensorShape({kNumDevs}); - cp.instance.impl_details.collective_name = "RingReduce"; - 
cp.instance.impl_details.subdiv_offsets.push_back(0); - cp.is_source = false; + cp->group.group_key = 1; + cp->group.group_size = kNumDevs; + cp->group.device_type = DeviceType("GPU"); + cp->group.num_tasks = num_tasks; + cp->instance.instance_key = 3; + cp->instance.type = REDUCTION_COLLECTIVE; + cp->instance.data_type = DataType(DT_FLOAT); + cp->instance.shape = TensorShape({kNumDevs}); + cp->instance.impl_details.collective_name = "RingReduce"; + cp->instance.impl_details.subdiv_offsets.push_back(0); + cp->is_source = false; for (int i = 0; i < kNumDevs; ++i) { int task_id = i / num_devs_per_task; int dev_id = i % num_devs_per_task; string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id); string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id); - cp.group.task_names.push_back(task_name); - cp.group.device_names.push_back(device_name); + cp->group.task_names.push_back(task_name); + cp->group.device_names.push_back(device_name); } return cp; } @@ -589,28 +597,29 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, TEST_F(RingReducerTest, InitializeParams) { const int kNumDevsPerTask = 8; const int kNumTasks = 3; - CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + core::ScopedUnref unref(cp); - cp.default_rank = 0; - cp.instance.impl_details.subdiv_offsets = {0, 4}; - RunSubdivPermsTest(&cp, + cp->default_rank = 0; + cp->instance.impl_details.subdiv_offsets = {0, 4}; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}}, {0, 4}); - cp.instance.impl_details.subdiv_offsets = {0, -4}; - RunSubdivPermsTest(&cp, + cp->instance.impl_details.subdiv_offsets = {0, -4}; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}}, {0, 3}); - cp.default_rank = 3; - cp.instance.impl_details.subdiv_offsets = {3, -3}; - RunSubdivPermsTest(&cp, + cp->default_rank = 3; + cp->instance.impl_details.subdiv_offsets = {3, -3}; + RunSubdivPermsTest(cp, {{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14, 15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18}, {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9, @@ -622,13 +631,14 @@ TEST_F(RingReducerTest, AutomaticSubdivs) { const int kNumDevsPerTask = 8; const int kNumTasks = 3; const int kNumDevs = kNumDevsPerTask * kNumTasks; - CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + core::ScopedUnref unref(cp); // Test automatic generation of subdiv offsets. - cp.default_rank = 0; - cp.instance.impl_details.subdiv_offsets.clear(); - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, + cp->default_rank = 0; + cp->instance.impl_details.subdiv_offsets.clear(); + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, {0}); // Set shape so that with 2 subdivs chunk_size is 3 MiB. 
This should cause 2 @@ -638,11 +648,11 @@ TEST_F(RingReducerTest, AutomaticSubdivs) { int num_chunks = kNumDevs * num_subdivs; size_t chunk_size = 3 * 1048576; // 3 MB size_t tensor_size = chunk_size * num_chunks; - cp.instance.shape = + cp->instance.shape = TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))}); } - cp.instance.impl_details.subdiv_offsets.clear(); - RunSubdivPermsTest(&cp, + cp->instance.impl_details.subdiv_offsets.clear(); + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, @@ -653,12 +663,13 @@ TEST_F(RingReducerTest, AutomaticSubdivs) { TEST_F(RingReducerTest, AutomaticSubdivUpperBound) { const int kNumDevsPerTask = 1; const int kNumTasks = 4; - CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + core::ScopedUnref unref(cp); - cp.default_rank = 0; - cp.instance.impl_details.subdiv_offsets.clear(); - cp.instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)}); - RunSubdivPermsTest(&cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0}); + cp->default_rank = 0; + cp->instance.impl_details.subdiv_offsets.clear(); + cp->instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)}); + RunSubdivPermsTest(cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0}); } // TODO(b/113171733): change to use TEST_P. diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index 1f380dab6f8..c5d846e1b57 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -137,13 +137,14 @@ void CollectiveParamResolverDistributed::CompleteGroupAsync( "running the same version of Tensorflow on all workers.")); return; } - CollectiveParams cp; - cp.group.group_key = request->group_key(); - cp.group.group_size = request->group_size(); - cp.group.device_type = DeviceType(request->device_type()); - cp.instance.type = CollectiveType(request->collective_type()); + auto* cp = new CollectiveParams(); + core::ScopedUnref unref(cp); + cp->group.group_key = request->group_key(); + cp->group.group_size = request->group_size(); + cp->group.device_type = DeviceType(request->device_type()); + cp->instance.type = CollectiveType(request->collective_type()); CompleteGroupDistributed( - request->device_attributes(), &cp, cancel_mgr, + request->device_attributes(), cp, cancel_mgr, [response, done](const Status& s, const GroupRec* gr) { if (s.ok()) { mutex_lock l(gr->mu); @@ -196,7 +197,7 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync( } StatusCallback done_and_cleanup = [cp, done](const Status& s) { done(s); - delete cp; + cp->Unref(); }; CompleteInstanceDistributed( request->device(), gr, cp, cancel_mgr, diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc index 1c62b17fe54..8c9f107b9dc 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc @@ -127,6 +127,13 @@ class FakeCache : public TestWorkerCache { }; class DeviceResDistTest : public ::testing::Test { + public: + ~DeviceResDistTest() override { + for (auto& name_param 
: cp_) { + name_param.second->Unref(); + } + } + protected: void DefineWorkers(int num_workers, int num_devices, const string& device_type, bool nccl) { @@ -181,20 +188,20 @@ class DeviceResDistTest : public ::testing::Test { } } - CollectiveParams CreateCollectiveParams(int num_workers, int num_devices, - const string& device_type) { + CollectiveParams* CreateCollectiveParams(int num_workers, int num_devices, + const string& device_type) { const int kGroupKey = 5; const int kInstanceKey = 3; - CollectiveParams cp; - cp.group.group_key = kGroupKey; - cp.group.group_size = num_workers * num_devices; - cp.group.device_type = DeviceType(device_type); - cp.group.num_tasks = num_workers; - cp.instance.instance_key = kInstanceKey; - cp.instance.type = REDUCTION_COLLECTIVE; - cp.instance.data_type = DT_FLOAT; - cp.instance.shape = TensorShape({64}); - cp.instance.impl_details.subdiv_offsets.push_back(0); + auto* cp = new CollectiveParams(); + cp->group.group_key = kGroupKey; + cp->group.group_size = num_workers * num_devices; + cp->group.device_type = DeviceType(device_type); + cp->group.num_tasks = num_workers; + cp->instance.instance_key = kInstanceKey; + cp->instance.type = REDUCTION_COLLECTIVE; + cp->instance.data_type = DT_FLOAT; + cp->instance.shape = TensorShape({64}); + cp->instance.impl_details.subdiv_offsets.push_back(0); return cp; } @@ -217,7 +224,7 @@ class DeviceResDistTest : public ::testing::Test { int group_size) { Device* device = nullptr; TF_CHECK_OK(device_mgrs_[task_name]->LookupDevice(device_name, &device)); - CollectiveParams* cp = &cp_[device_name]; + CollectiveParams* cp = cp_[device_name]; CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name].get(); CHECK(cp_res); cp_res->CompleteParamsAsync( @@ -252,19 +259,19 @@ class DeviceResDistTest : public ::testing::Test { string device_name = strings::StrCat(task_name, "/device:CPU:", di); int idx = wi * num_devices + di; TF_ASSERT_OK(status_[device_name]); - EXPECT_EQ(cp_[device_name].default_rank, idx); - EXPECT_EQ(cp_[device_name].group.device_names.size(), dev_count); - EXPECT_EQ(cp_[device_name].group.device_names[idx], device_name); - EXPECT_EQ(cp_[device_name].group.task_names[idx], task_name); - ValidateDeviceResolver(cp_[device_name], task_name); + EXPECT_EQ(cp_[device_name]->default_rank, idx); + EXPECT_EQ(cp_[device_name]->group.device_names.size(), dev_count); + EXPECT_EQ(cp_[device_name]->group.device_names[idx], device_name); + EXPECT_EQ(cp_[device_name]->group.task_names[idx], task_name); + ValidateDeviceResolver(*cp_[device_name], task_name); if (idx > 0) { - EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key, - cp_[device_name].group.runtime_details.communicator_key); + EXPECT_EQ(cp_[dev0]->group.runtime_details.communicator_key, + cp_[device_name]->group.runtime_details.communicator_key); for (int i = 0; i < dev_count; ++i) { - EXPECT_EQ(cp_[dev0].group.device_names[i], - cp_[device_name].group.device_names[i]); - EXPECT_EQ(cp_[dev0].group.task_names[i], - cp_[device_name].group.task_names[i]); + EXPECT_EQ(cp_[dev0]->group.device_names[i], + cp_[device_name]->group.device_names[i]); + EXPECT_EQ(cp_[dev0]->group.task_names[i], + cp_[device_name]->group.task_names[i]); } } } @@ -287,6 +294,9 @@ class DeviceResDistTest : public ::testing::Test { for (int i = 0; i < num_devices; ++i) { string device_name = strings::StrCat(worker_name, "/device:", device_type, ":", i); + if (cp_.find(device_name) != cp_.end()) { + cp_[device_name]->Unref(); + } cp_[device_name] = 
CreateCollectiveParams(num_workers, num_devices, device_type); status_.erase(device_name); @@ -305,7 +315,7 @@ class DeviceResDistTest : public ::testing::Test { absl::flat_hash_map<string, std::vector<string>> dev_by_task_; absl::flat_hash_map<string, std::unique_ptr<FakeWorker>> workers_; // Below are keyed by device names; - absl::flat_hash_map<string, CollectiveParams> cp_; + absl::flat_hash_map<string, CollectiveParams*> cp_; absl::flat_hash_map<string, Status> status_; mutex mu_; int num_done_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc index 36e26ba9fe9..b1126471b5c 100644 --- a/tensorflow/core/framework/collective.cc +++ b/tensorflow/core/framework/collective.cc @@ -169,7 +169,7 @@ string CollectiveParams::ToString() const { CollectiveContext::CollectiveContext( CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator, const DeviceMgr* dev_mgr, OpKernelContext* ctx, - OpKernelContext::Params* op_params, const CollectiveParams& col_params, + OpKernelContext::Params* op_params, const CollectiveParams* col_params, const string& exec_key, int64 step_id, const Tensor* input, Tensor* output) : col_exec(col_exec), nccl_communicator(nccl_communicator), @@ -182,7 +182,7 @@ CollectiveContext::CollectiveContext( input(input), output(output), device(nullptr), - device_name(col_params.group.device_names[col_params.default_rank]) {} + device_name(col_params->group.device_names[col_params->default_rank]) {} /*static*/ int64 CollectiveExecutor::kInvalidId = -1; diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index cd4c28e1d2f..1b5b8f7789b 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -132,7 +132,7 @@ struct CollTaskParams { }; // Unique to a single CollectiveOp node. 
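+// Refcounted so that asynchronous users can hold a reference for as long as
+// they may access these params (see e.g. NcclCommunicator::Enqueue).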
-struct CollectiveParams { +struct CollectiveParams : public core::RefCounted { CollGroupParams group; CollInstanceParams instance; CollTaskParams task; @@ -298,7 +298,7 @@ class CollectiveExecutor : public core::RefCounted { virtual void StartAbort(const Status& s) {} virtual void ExecuteAsync(OpKernelContext* ctx, - const CollectiveParams& col_params, + const CollectiveParams* col_params, const string& exec_key, StatusCallback done) { done(errors::Internal( "A collective Op has been called in a context in which " @@ -367,7 +367,7 @@ struct CollectiveContext { const DeviceMgr* dev_mgr; // Not owned OpKernelContext* op_ctx; // Not owned OpKernelContext::Params* op_params; // Not owned - const CollectiveParams& col_params; + const CollectiveParams* col_params; // Not owned const string exec_key; const int64 step_id; const Tensor* input; // Not owned @@ -380,7 +380,7 @@ struct CollectiveContext { NcclCommunicatorInterface* nccl_communicator, const DeviceMgr* dev_mgr, OpKernelContext* ctx, OpKernelContext::Params* op_params, - const CollectiveParams& col_params, const string& exec_key, + const CollectiveParams* col_params, const string& exec_key, int64 step_id, const Tensor* input, Tensor* output); }; diff --git a/tensorflow/core/kernels/collective_nccl.cc b/tensorflow/core/kernels/collective_nccl.cc index 44e0b07e9ad..04c6e8a337b 100644 --- a/tensorflow/core/kernels/collective_nccl.cc +++ b/tensorflow/core/kernels/collective_nccl.cc @@ -61,7 +61,7 @@ Status NcclBase::InitializeCollectiveParams(CollectiveParams* col_params) { Status NcclBase::InitializeCollectiveContext( std::shared_ptr<CollectiveContext> col_ctx) { col_ctx_ = col_ctx; - col_params_ = &col_ctx->col_params; + col_params_ = col_ctx->col_params; return collective_util::InitializeDeviceAndLocality( col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, &col_ctx->device_locality); diff --git a/tensorflow/core/kernels/collective_nccl_reducer.cc b/tensorflow/core/kernels/collective_nccl_reducer.cc index 777c5fc8fc7..6aeec00c1da 100644 --- a/tensorflow/core/kernels/collective_nccl_reducer.cc +++ b/tensorflow/core/kernels/collective_nccl_reducer.cc @@ -88,6 +88,9 @@ void NcclReducer::Run(StatusCallback done) { } else { done_callback = std::move(done); } + // Hold a ref to col_params for the rest of this function. + col_params_->Ref(); + core::ScopedUnref unref(col_params_); col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done_callback)); // If no final_op, then this OpKernel is non-blocking. diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc index f8f18751d86..f9516fd8c13 100644 --- a/tensorflow/core/kernels/collective_ops.cc +++ b/tensorflow/core/kernels/collective_ops.cc @@ -53,7 +53,9 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c, class CollectiveOpV1Kernel : public AsyncOpKernel { public: explicit CollectiveOpV1Kernel(OpKernelConstruction* c) - : AsyncOpKernel(c), name_(name()) {} + : AsyncOpKernel(c), name_(name()), col_params_(new CollectiveParams()) {} + + ~CollectiveOpV1Kernel() override { col_params_->Unref(); } void ComputeAsync(OpKernelContext* c, DoneCallback done) override { CollectiveExecutor* col_exec = c->collective_executor(); @@ -88,28 +90,31 @@ class CollectiveOpV1Kernel : public AsyncOpKernel { // A string encoding instance, frame and iter to be handed off to // the implementation for use in generating RecvBuf keys. 
string GetCollectiveKey(OpKernelContext* c) { - return CollectiveKey(c, col_params_.group.group_key, - col_params_.instance.instance_key); + return CollectiveKey(c, col_params_->group.group_key, + col_params_->instance.instance_key); } // Returns false if calling invocation of ComputeAsync should return // immediately. bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec, const DoneCallback& done) { - if (col_params_.group.group_size > col_params_.group.device_names.size()) { + if (col_params_->group.group_size > + col_params_->group.device_names.size()) { // This is the first invocation: Finish initializing col_params_. // Schedule the `CompleteParamsAsync` call on a work queue that can handle // blocking work because it's not guaranteed that this call cannot block. - c->collective_executor()->RunClosure([this, c, done, col_exec]() { + col_params_->Ref(); + c->collective_executor()->RunClosure([this, c, col_exec, done]() { VLOG(1) << "CollectiveOpKernel CompleteParams for collective " - << col_params_.name << " device " << c->device()->name() - << " group " << col_params_.group.group_key << " instance " - << col_params_.instance.instance_key; + << col_params_->name << " device " << c->device()->name() + << " group " << col_params_->group.group_key << " instance " + << col_params_->instance.instance_key; col_exec->CompleteParamsAsync( - c->device()->attributes(), &col_params_, c->cancellation_manager(), + c->device()->attributes(), col_params_, c->cancellation_manager(), [this, c, done](const Status& s) { + core::ScopedUnref unref(col_params_); if (s.ok()) { - col_params_.instance.impl_details.dependencies = dependencies_; + col_params_->instance.impl_details.dependencies = dependencies_; ComputeAsync(c, done); } else { c->SetStatus(s); @@ -128,7 +133,7 @@ class CollectiveOpV1Kernel : public AsyncOpKernel { DoneCallback done) = 0; string name_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::vector<int32> dependencies_; }; @@ -136,25 +141,25 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveGatherOpKernel(OpKernelConstruction* c) : CollectiveOpV1Kernel(c) { - col_params_.instance.type = GATHER_COLLECTIVE; - OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); + col_params_->instance.type = GATHER_COLLECTIVE; + OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size)); OP_REQUIRES( - c, col_params_.group.group_size > 0, + c, col_params_->group.group_size > 0, errors::InvalidArgument("group_size must be positive integer but got ", - col_params_.group.group_size)); - OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key)); + col_params_->group.group_size)); + OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key)); OP_REQUIRES_OK( - c, c->GetAttr("instance_key", &col_params_.instance.instance_key)); - OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type)); + c, c->GetAttr("instance_key", &col_params_->instance.instance_key)); + OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type)); OP_REQUIRES_OK( c, c->GetAttr("communication_hint", - &col_params_.instance.impl_details.communication_hint)); + &col_params_->instance.impl_details.communication_hint)); OP_REQUIRES_OK( c, c->GetAttr("timeout_seconds", - &col_params_.instance.impl_details.timeout_seconds)); + &col_params_->instance.impl_details.timeout_seconds)); const NodeDef& real_node = c->def(); - col_params_.name = strings::StrCat(real_node.name(), ": 
Gather"); - col_params_.group.device_type = c->device_type(); + col_params_->name = strings::StrCat(real_node.name(), ": Gather"); + col_params_->group.device_type = c->device_type(); } protected: @@ -162,8 +167,8 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel { DoneCallback done) override { auto output_shape = c->input(0).shape(); output_shape.set_dim( - 0, output_shape.dim_size(0) * col_params_.group.group_size); - col_params_.instance.shape = output_shape; + 0, output_shape.dim_size(0) * col_params_->group.group_size); + col_params_->instance.shape = output_shape; // Allocate output on the first pass through this function. This must be // done immediately, while we're still in the executor thread. Otherwise @@ -173,24 +178,24 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel { // Allocate the output tensor. Tensor* output = nullptr; OP_REQUIRES_OK_ASYNC( - c, c->allocate_output(0, col_params_.instance.shape, &output), done); + c, c->allocate_output(0, col_params_->instance.shape, &output), done); } if (!CanProceedWithCompute(c, col_exec, done)) return; - auto actual_done = [c, group_key = col_params_.group.group_key, - instance_key = col_params_.instance.instance_key, - done](const Status& s) { + auto actual_done = [c, col_params = col_params_, done](const Status& s) { VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() - << " group " << group_key << " instance " << instance_key - << " status " << s; + << " group " << col_params->group.group_key << " instance " + << col_params->instance.instance_key << " status " << s; OP_REQUIRES_OK_ASYNC(c, s, done); done(); + col_params->Unref(); }; VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync start for collective " - << col_params_.name << " device " << c->device()->name() - << " group " << col_params_.group.group_key << " instance " - << col_params_.instance.instance_key; + << col_params_->name << " device " << c->device()->name() + << " group " << col_params_->group.group_key << " instance " + << col_params_->instance.instance_key; + col_params_->Ref(); col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done); } @@ -207,18 +212,18 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveReduceOpKernel(OpKernelConstruction* c) : CollectiveOpV1Kernel(c) { - col_params_.instance.type = REDUCTION_COLLECTIVE; - OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); + col_params_->instance.type = REDUCTION_COLLECTIVE; + OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size)); OP_REQUIRES( - c, col_params_.group.group_size > 0, + c, col_params_->group.group_size > 0, errors::InvalidArgument("group_size must be positive integer but got ", - col_params_.group.group_size)); - OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key)); + col_params_->group.group_size)); + OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key)); OP_REQUIRES_OK( - c, c->GetAttr("instance_key", &col_params_.instance.instance_key)); + c, c->GetAttr("instance_key", &col_params_->instance.instance_key)); OP_REQUIRES_OK( c, c->GetAttr("subdiv_offsets", - &col_params_.instance.impl_details.subdiv_offsets)); + &col_params_->instance.impl_details.subdiv_offsets)); string merge_op_name; OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name)); if (merge_op_name == "Max") { @@ -232,24 +237,26 @@ class CollectiveReduceOpKernel : public 
CollectiveOpV1Kernel { errors::InvalidArgument( "final_op must be one of {\"Id\", \"Div\"} but got ", final_op_name)); - OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type)); + OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type)); OP_REQUIRES_OK(c, c->GetAttr("wait_for", &dependencies_)); OP_REQUIRES_OK( c, c->GetAttr("communication_hint", - &col_params_.instance.impl_details.communication_hint)); + &col_params_->instance.impl_details.communication_hint)); OP_REQUIRES_OK( c, c->GetAttr("timeout_seconds", - &col_params_.instance.impl_details.timeout_seconds)); - VLOG(2) << "CollectiveReduce instance " << col_params_.instance.instance_key - << " merge_op " << merge_op_name << " final_op " << final_op_name + &col_params_->instance.impl_details.timeout_seconds)); + VLOG(2) << "CollectiveReduce instance " + << col_params_->instance.instance_key << " merge_op " + << merge_op_name << " final_op " << final_op_name << " communication_hint " - << col_params_.instance.impl_details.communication_hint - << " timeout " << col_params_.instance.impl_details.timeout_seconds; + << col_params_->instance.impl_details.communication_hint + << " timeout " + << col_params_->instance.impl_details.timeout_seconds; const NodeDef& real_node = c->def(); - col_params_.name = strings::StrCat(real_node.name(), ": Reduce(", - merge_op_name, ",", final_op_name, ")"); - col_params_.group.device_type = c->device_type(); + col_params_->name = strings::StrCat(real_node.name(), ": Reduce(", + merge_op_name, ",", final_op_name, ")"); + col_params_->group.device_type = c->device_type(); // Find the OpKernels by name, type and device type. NodeDef sub_node; @@ -257,12 +264,12 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel { sub_node.add_input(real_node.input(0)); sub_node.add_input(real_node.input(0)); sub_node.set_device(real_node.device()); - SetAttrValue(col_params_.instance.data_type, + SetAttrValue(col_params_->instance.data_type, &(*sub_node.mutable_attr())["T"]); merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node); final_op_ = BuildOpKernel(c, final_op_name, &sub_node); - col_params_.merge_op = merge_op_.get(); - col_params_.final_op = final_op_.get(); + col_params_->merge_op = merge_op_.get(); + col_params_->final_op = final_op_.get(); } protected: @@ -279,24 +286,24 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel { c->forward_input_or_allocate_output( {0}, 0, c->input(0).shape(), &output), done); - col_params_.instance.shape = c->input(0).shape(); + col_params_->instance.shape = c->input(0).shape(); } if (!CanProceedWithCompute(c, col_exec, done)) return; - auto actual_done = [c, group_key = col_params_.group.group_key, - instance_key = col_params_.instance.instance_key, - done](const Status& s) { + auto actual_done = [c, col_params = col_params_, done](const Status& s) { VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() - << " group " << group_key << " instance " << instance_key - << " status " << s; + << " group " << col_params->group.group_key << " instance " + << col_params->instance.instance_key << " status " << s; OP_REQUIRES_OK_ASYNC(c, s, done); done(); + col_params->Unref(); }; VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync start for collective " - << col_params_.name << " device " << c->device()->name() - << " group " << col_params_.group.group_key << " instance " - << col_params_.instance.instance_key; + << col_params_->name << " device " << 
c->device()->name() + << " group " << col_params_->group.group_key << " instance " + << col_params_->instance.instance_key; + col_params_->Ref(); col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done); } @@ -315,29 +322,29 @@ class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c) : CollectiveOpV1Kernel(c) { - col_params_.instance.type = BROADCAST_COLLECTIVE; - OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); + col_params_->instance.type = BROADCAST_COLLECTIVE; + OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size)); OP_REQUIRES( - c, col_params_.group.group_size > 0, + c, col_params_->group.group_size > 0, errors::InvalidArgument("group_size must be positive integer but got ", - col_params_.group.group_size)); - OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key)); + col_params_->group.group_size)); + OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key)); OP_REQUIRES_OK( - c, c->GetAttr("instance_key", &col_params_.instance.instance_key)); - OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type)); - OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape)); + c, c->GetAttr("instance_key", &col_params_->instance.instance_key)); + OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type)); + OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape)); OP_REQUIRES_OK( c, c->GetAttr("communication_hint", - &col_params_.instance.impl_details.communication_hint)); + &col_params_->instance.impl_details.communication_hint)); OP_REQUIRES_OK( c, c->GetAttr("timeout_seconds", - &col_params_.instance.impl_details.timeout_seconds)); - col_params_.is_source = true; - col_params_.instance.impl_details.subdiv_offsets = {0}; + &col_params_->instance.impl_details.timeout_seconds)); + col_params_->is_source = true; + col_params_->instance.impl_details.subdiv_offsets = {0}; - col_params_.name = - strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")"); - col_params_.group.device_type = c->device_type(); + col_params_->name = + strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")"); + col_params_->group.device_type = c->device_type(); } protected: @@ -352,30 +359,30 @@ class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel { Tensor* output = nullptr; OP_REQUIRES_OK_ASYNC(c, c->forward_input_or_allocate_output( - {0}, 0, col_params_.instance.shape, &output), + {0}, 0, col_params_->instance.shape, &output), done); } if (!CanProceedWithCompute(c, col_exec, done)) return; OP_REQUIRES_ASYNC( - c, col_params_.instance.shape.IsSameSize(c->input(0).shape()), - errors::Internal("Declared shape of op ", col_params_.name, + c, col_params_->instance.shape.IsSameSize(c->input(0).shape()), + errors::Internal("Declared shape of op ", col_params_->name, " does not match shape of input"), done); - auto actual_done = [c, group_key = col_params_.group.group_key, - instance_key = col_params_.instance.instance_key, - done](const Status& s) { + auto actual_done = [c, col_params = col_params_, done](const Status& s) { VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() - << " group " << group_key << " instance " << instance_key - << " status " << s; + << " group " << col_params->group.group_key << " instance " + << col_params->instance.instance_key << " status " << s; 
OP_REQUIRES_OK_ASYNC(c, s, done); done(); + col_params->Unref(); }; VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync start for collective " - << col_params_.name << " device " << c->device()->name() - << " group " << col_params_.group.group_key << " instance " - << col_params_.instance.instance_key; + << col_params_->name << " device " << c->device()->name() + << " group " << col_params_->group.group_key << " instance " + << col_params_->instance.instance_key; + col_params_->Ref(); col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done); } @@ -392,29 +399,29 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c) : CollectiveOpV1Kernel(c) { - col_params_.instance.type = BROADCAST_COLLECTIVE; - OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); + col_params_->instance.type = BROADCAST_COLLECTIVE; + OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size)); OP_REQUIRES( - c, col_params_.group.group_size > 0, + c, col_params_->group.group_size > 0, errors::InvalidArgument("group_size must be positive integer but got ", - col_params_.group.group_size)); - OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key)); + col_params_->group.group_size)); + OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key)); OP_REQUIRES_OK( - c, c->GetAttr("instance_key", &col_params_.instance.instance_key)); - OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type)); - OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape)); + c, c->GetAttr("instance_key", &col_params_->instance.instance_key)); + OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type)); + OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape)); OP_REQUIRES_OK( c, c->GetAttr("communication_hint", - &col_params_.instance.impl_details.communication_hint)); + &col_params_->instance.impl_details.communication_hint)); OP_REQUIRES_OK( c, c->GetAttr("timeout_seconds", - &col_params_.instance.impl_details.timeout_seconds)); - col_params_.is_source = false; - col_params_.instance.impl_details.subdiv_offsets = {0}; + &col_params_->instance.impl_details.timeout_seconds)); + col_params_->is_source = false; + col_params_->instance.impl_details.subdiv_offsets = {0}; - col_params_.name = - strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")"); - col_params_.group.device_type = c->device_type(); + col_params_->name = + strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")"); + col_params_->group.device_type = c->device_type(); } protected: @@ -428,24 +435,24 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel { // No input, so must allocate output. 
Tensor* output = nullptr; OP_REQUIRES_OK_ASYNC( - c, c->allocate_output(0, col_params_.instance.shape, &output), done); + c, c->allocate_output(0, col_params_->instance.shape, &output), done); } if (!CanProceedWithCompute(c, col_exec, done)) return; - auto actual_done = [c, group_key = col_params_.group.group_key, - instance_key = col_params_.instance.instance_key, - done](const Status& s) { + auto actual_done = [c, col_params = col_params_, done](const Status& s) { VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() - << " group " << group_key << " instance_key " << instance_key - << " status " << s; + << " group " << col_params->group.group_key << " instance_key " + << col_params->instance.instance_key << " status " << s; OP_REQUIRES_OK_ASYNC(c, s, done); done(); + col_params->Unref(); }; VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync start for collective " - << col_params_.name << " device " << c->device()->name() - << " group " << col_params_.group.group_key << " instance " - << col_params_.instance.instance_key; + << col_params_->name << " device " << c->device()->name() + << " group " << col_params_->group.group_key << " instance " + << col_params_->instance.instance_key; + col_params_->Ref(); col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done); } @@ -534,8 +541,8 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { << col_params->instance.instance_key; auto done_with_cleanup = [col_params, done = std::move(done)]() { - delete col_params; done(); + col_params->Unref(); }; // Allocate the output tensor, trying to reuse the input. @@ -577,7 +584,7 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->ExecuteAsync( - c, *col_params, + c, col_params, CollectiveKey(c, col_params->group.group_key, col_params->instance.instance_key), actual_done); @@ -673,8 +680,8 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel { col_params->instance.shape = output_shape; auto done_with_cleanup = [col_params, done = std::move(done)]() { - delete col_params; done(); + col_params->Unref(); }; Tensor* output = nullptr; @@ -714,7 +721,7 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel { << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->ExecuteAsync( - c, *col_params, + c, col_params, CollectiveKey(c, col_params->group.group_key, col_params->instance.instance_key), actual_done); @@ -797,8 +804,8 @@ class CollectiveBcastSendV2OpKernel : public AsyncOpKernel { << col_params->instance.instance_key; auto done_with_cleanup = [col_params, done = std::move(done)]() { - delete col_params; done(); + col_params->Unref(); }; // Allocate the output tensor, trying to reuse the input. 
@@ -840,7 +847,7 @@ class CollectiveBcastSendV2OpKernel : public AsyncOpKernel { << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->ExecuteAsync( - c, *col_params, + c, col_params, CollectiveKey(c, col_params->group.group_key, col_params->instance.instance_key), actual_done); @@ -905,8 +912,8 @@ class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel { auto col_params = new CollectiveParams(); auto done_with_cleanup = [col_params, done = std::move(done)]() { - delete col_params; done(); + col_params->Unref(); }; OP_REQUIRES_OK_ASYNC( @@ -969,7 +976,7 @@ class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel { << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->ExecuteAsync( - c, *col_params, + c, col_params, CollectiveKey(c, col_params->group.group_key, col_params->instance.instance_key), actual_done); diff --git a/tensorflow/core/nccl/collective_communicator.cc b/tensorflow/core/nccl/collective_communicator.cc index 233f1ecefd3..2f0659eb121 100644 --- a/tensorflow/core/nccl/collective_communicator.cc +++ b/tensorflow/core/nccl/collective_communicator.cc @@ -69,17 +69,17 @@ std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator() { void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, StatusCallback done) { - const CollectiveParams& col_params = col_ctx->col_params; - const int num_global_devices = col_params.group.group_size; - const int num_local_devices = col_params.group.num_devices_per_task.at( - col_params.group.task_names[col_params.default_rank]); + const CollectiveParams* col_params = col_ctx->col_params; + const int num_global_devices = col_params->group.group_size; + const int num_local_devices = col_params->group.num_devices_per_task.at( + col_params->group.task_names[col_params->default_rank]); const string nccl_collective_key = NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id); auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream(); auto* gpu_info = col_ctx->op_ctx->device()->tensorflow_gpu_device_info(); auto participant = absl::make_unique<NcclManager::Participant>( compute_stream->parent(), compute_stream, gpu_info, col_ctx->input, - col_ctx->output, col_ctx->col_params.default_rank, + col_ctx->output, col_ctx->col_params->default_rank, /*done_callback=*/nullptr); CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager(); if (cancel_mgr == nullptr) { @@ -105,15 +105,24 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, } NcclManager::Context context( nccl_collective_key, num_local_devices, num_global_devices, - col_params.group.runtime_details.communicator_key, - col_params.source_rank); - VLOG(1) << "NcclCommunicator::Enqueue type " << col_params.instance.type - << " num_tasks " << col_params.group.num_tasks << " current task " - << col_params.group.task_names[col_params.default_rank] + col_params->group.runtime_details.communicator_key, + col_params->source_rank); + VLOG(1) << "NcclCommunicator::Enqueue type " << col_params->instance.type + << " num_tasks " << col_params->group.num_tasks << " current task " + << col_params->group.task_names[col_params->default_rank] << " num local devices " << num_local_devices << " num global devices " << num_global_devices << " device " << col_ctx->device_name << " instance " - << col_params.instance.instance_key; + << col_params->instance.instance_key; + // Hold a ref to col_params for the rest of this function. 
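+  // Without this ref, the participant's done_callback could release the last
+  // reference while this function is still reading col_params, for example
+  // in the WaitForDependencies / UnblockDependencies calls below.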
+ // NOTE: an alternate design can be one in which CollectiveParams is not + // refcounted. In such a design, we would need to ensure that the + // done_callback of each participant is called only after this function is + // done with accessing the params. This would likely require some + // coordination mechanism, and may even require the participant thread to + // block until after UnblockDependencies is called below. + col_params->Ref(); + core::ScopedUnref unref(col_params); // `AddTo*` performs consistency checks for the NCCL call and enqueues the // `Participant` struct locally. When all local participants with this // `nccl_collective_key` have called `AddToAllReduce` and @@ -123,10 +132,11 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, // The `NcclManager` uses a dedicated CUDA stream for NCCL kernels. At this // point, it synchronizes the NCCL stream with the compute stream, and then // enqueues the NCCL kernel on the NCCL stream. - switch (col_params.instance.type) { + switch (col_params->instance.type) { case REDUCTION_COLLECTIVE: { ncclRedOp_t reduction_op; - Status s = ReductionOp(col_params.merge_op->type_string(), &reduction_op); + Status s = + ReductionOp(col_params->merge_op->type_string(), &reduction_op); if (!s.ok()) { participant->done_callback(s); return; @@ -140,7 +150,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, break; } case BROADCAST_COLLECTIVE: { - if (col_params.is_source) { + if (col_params->is_source) { nccl_manager_.AddBroadcastSend(std::move(participant), context); } else { nccl_manager_.AddBroadcastRecv(std::move(participant), context); @@ -149,7 +159,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, } default: { participant->done_callback(errors::Internal("Unexpected CollectiveType ", - col_params.instance.type)); + col_params->instance.type)); return; } } @@ -175,7 +185,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, // ready to go. profiler::TraceMe activity("WaitForDependencies", profiler::TraceMeLevel::kInfo); - col_ctx->col_exec->WaitForDependencies(col_params); + col_ctx->col_exec->WaitForDependencies(*col_params); nccl_manager_.SignalMultiNodeReady(nccl_collective_key); } { @@ -184,7 +194,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, // implementation of `UnblockDependencies` keeps track of the number of // devices that have launched. profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo); - col_ctx->col_exec->UnblockDependencies(col_params); + col_ctx->col_exec->UnblockDependencies(*col_params); } }
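
For reference, the ownership contract introduced above boils down to: any code
that may still touch a `CollectiveParams` after handing it to asynchronous work
takes a reference first and releases it when done, so the params outlive every
accessor. Below is a minimal, self-contained sketch of that pattern;
`RefCounted` and `ScopedUnref` are simplified stand-ins for
`tensorflow::core::RefCounted` and `tensorflow::core::ScopedUnref`, and
`Params` is a toy type, not the real `CollectiveParams`.

#include <atomic>
#include <cassert>
#include <thread>

// Simplified stand-in for tensorflow::core::RefCounted.
class RefCounted {
 public:
  RefCounted() : count_(1) {}  // a newly constructed object carries one ref
  void Ref() { count_.fetch_add(1, std::memory_order_relaxed); }
  void Unref() {
    // Acquire/release so the deleting thread sees all prior writes.
    if (count_.fetch_sub(1, std::memory_order_acq_rel) == 1) delete this;
  }

 protected:
  virtual ~RefCounted() = default;  // destroyed only via Unref()

 private:
  std::atomic<int> count_;
};

// Simplified stand-in for tensorflow::core::ScopedUnref.
class ScopedUnref {
 public:
  explicit ScopedUnref(RefCounted* obj) : obj_(obj) {}
  ~ScopedUnref() {
    if (obj_ != nullptr) obj_->Unref();
  }
  ScopedUnref(const ScopedUnref&) = delete;
  ScopedUnref& operator=(const ScopedUnref&) = delete;

 private:
  RefCounted* obj_;
};

// Toy params type; the real CollectiveParams has many more fields.
struct Params : public RefCounted {
  int group_size = 0;
};

int main() {
  auto* params = new Params();  // refcount == 1, owned by this scope
  params->group_size = 8;

  params->Ref();  // take a reference before handing params to async work
  std::thread done_callback([params] {
    ScopedUnref unref(params);        // the callback releases its own ref
    assert(params->group_size == 8);  // safe: the ref keeps params alive
  });

  {
    // This scope still owns the original reference, so reading params here
    // cannot race with destruction even if done_callback finishes first.
    ScopedUnref unref(params);
    assert(params->group_size == 8);
  }

  done_callback.join();
  return 0;
}

This mirrors the pattern the patch applies throughout: `NcclReducer::Run` takes
a ref and releases it with `core::ScopedUnref`, and the V1 kernels call
`col_params_->Ref()` before `ExecuteAsync` with the matching `Unref()` in
`actual_done`.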