diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index b32a7385f65..392c0e9592b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -249,6 +249,20 @@ class BaseGPUDevice::StreamGroupFactory {
       VLOG(2) << "Created stream[" << stream_group_within_gpu
               << "] = " << group->compute;
 
+#if TENSORFLOW_USE_ROCM
+      // ROCm streams are lightweight and will not necessarily trigger device
+      // queue init until they are first used. For optimal performance,
+      // compute and nccl streams must be immediate siblings.
+      group->nccl = new se::Stream(executor);
+      group->nccl->Init();
+      VLOG(2) << "Created nccl_stream[" << stream_group_within_gpu
+              << "] = " << group->nccl;
+
+      // Force underlying resource creation now.
+      group->compute->ThenWaitFor(group->nccl);
+      group->nccl->ThenWaitFor(group->compute);
+#endif
+
       group->host_to_device = new se::Stream(executor);
       group->host_to_device->Init();
       VLOG(2) << "Created host_to_device_stream[" << stream_group_within_gpu
@@ -371,8 +385,12 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
     streams_.push_back(StreamGroupFactory::Global().GetOrCreate(
         tf_gpu_id_, i, executor_, options.config.gpu_options()));
     device_contexts_.push_back(new GPUDeviceContext(
-        i, streams_.back()->compute, streams_.back()->host_to_device,
-        streams_.back()->device_to_host, streams_.back()->device_to_device));
+        i, streams_.back()->compute,
+#if TENSORFLOW_USE_ROCM
+        streams_.back()->nccl,
+#endif
+        streams_.back()->host_to_device, streams_.back()->device_to_host,
+        streams_.back()->device_to_device));
   }
 
   em_ = EventMgrFactory::Singleton()->GetEventMgr(executor_,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index cbba89d0d05..ae7611fee72 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -137,6 +137,9 @@ class BaseGPUDevice : public LocalDevice {
   friend class GPUDeviceTestHelper;
   struct StreamGroup {
     se::Stream* compute = nullptr;
+#if TENSORFLOW_USE_ROCM
+    se::Stream* nccl = nullptr;
+#endif
    se::Stream* host_to_device = nullptr;
     se::Stream* device_to_host = nullptr;
     gtl::InlinedVector<se::Stream*, 4> device_to_device;
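The ROCm-specific block above relies on two StreamExecutor behaviors: streams are cheap to create, and a stream's underlying device queue may not exist until the stream is first used. Below is a minimal sketch of the same pattern outside the patch; the struct and function names are illustrative only, not TensorFlow APIs, and ownership is simplified with `std::unique_ptr` rather than the raw pointers held by `StreamGroup`.

```cpp
// Sketch only: create the compute and nccl streams back to back so they are
// immediate siblings, then record a wait in each direction. Each wait is
// trivially satisfied (both streams are empty), but touching the streams once
// forces the underlying device queues to be created now rather than lazily.
#include <memory>

#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace se = ::stream_executor;

struct SiblingStreams {
  std::unique_ptr<se::Stream> compute;
  std::unique_ptr<se::Stream> nccl;
};

SiblingStreams CreateSiblingStreams(se::StreamExecutor* executor) {
  SiblingStreams s;
  s.compute = std::make_unique<se::Stream>(executor);
  s.compute->Init();
  s.nccl = std::make_unique<se::Stream>(executor);
  s.nccl->Init();

  // Cross-synchronize once to force queue creation on both streams.
  s.compute->ThenWaitFor(s.nccl.get());
  s.nccl->ThenWaitFor(s.compute.get());
  return s;
}
```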
diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h
index eab46b79c17..4f6325a6d19 100644
--- a/tensorflow/core/common_runtime/gpu_device_context.h
+++ b/tensorflow/core/common_runtime/gpu_device_context.h
@@ -30,18 +30,28 @@ class GPUDeviceContext : public DeviceContext {
  public:
   // Does not take ownership of streams.
   GPUDeviceContext(int stream_id, se::Stream* stream,
+#if TENSORFLOW_USE_ROCM
+                   se::Stream* nccl_stream,
+#endif
                    se::Stream* host_to_device_stream,
                    se::Stream* device_to_host_stream,
                    gtl::InlinedVector<se::Stream*, 4> device_to_device_stream)
       : stream_id_(stream_id),
         stream_(stream),
+#if TENSORFLOW_USE_ROCM
+        nccl_stream_(nccl_stream),
+#endif
         host_to_device_stream_(host_to_device_stream),
         device_to_host_stream_(device_to_host_stream),
-        device_to_device_stream_(device_to_device_stream) {}
+        device_to_device_stream_(device_to_device_stream) {
+  }
 
   ~GPUDeviceContext() override {}
 
   se::Stream* stream() const override { return stream_; }
+#if TENSORFLOW_USE_ROCM
+  se::Stream* nccl_stream() const { return nccl_stream_; }
+#endif
   se::Stream* host_to_device_stream() const { return host_to_device_stream_; }
   se::Stream* device_to_host_stream() const { return device_to_host_stream_; }
   se::Stream* device_to_device_stream(int index) const {
@@ -72,6 +82,10 @@ class GPUDeviceContext : public DeviceContext {
   // The default primary stream to use for this context.
   // All the memory belongs to this stream.
   se::Stream* stream_;
+#if TENSORFLOW_USE_ROCM
+  // The stream to use for nccl operations.
+  se::Stream* nccl_stream_;
+#endif
   // The stream to use for copying data from host into GPU.
   se::Stream* host_to_device_stream_;
   // The stream to use for copying data from GPU to host.
diff --git a/tensorflow/core/kernels/collective_nccl_broadcaster.cc b/tensorflow/core/kernels/collective_nccl_broadcaster.cc
index 6e1da95faa7..59aecd90309 100644
--- a/tensorflow/core/kernels/collective_nccl_broadcaster.cc
+++ b/tensorflow/core/kernels/collective_nccl_broadcaster.cc
@@ -32,9 +32,8 @@ void NcclBroadcaster::Run(StatusCallback done) {
   string nccl_collective_key =
       NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
   auto participant = absl::make_unique<NcclManager::Participant>(
-      compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-      gpu_info->gpu_id, col_ctx_->input, col_ctx_->output,
-      col_params_->default_rank, std::move(done));
+      compute_stream->parent(), compute_stream, gpu_info, col_ctx_->input,
+      col_ctx_->output, col_params_->default_rank, std::move(done));
   VLOG(1)
       << "NcclBroadcast calling NcclManager::AddBroadcastSend/Recv num_tasks "
       << col_params_->group.num_tasks << " current task "
diff --git a/tensorflow/core/kernels/collective_nccl_gatherer.cc b/tensorflow/core/kernels/collective_nccl_gatherer.cc
index 144d830befb..e219dffdc33 100644
--- a/tensorflow/core/kernels/collective_nccl_gatherer.cc
+++ b/tensorflow/core/kernels/collective_nccl_gatherer.cc
@@ -32,9 +32,8 @@ void NcclGatherer::Run(StatusCallback done) {
   string nccl_collective_key =
       NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
   auto participant = absl::make_unique<NcclManager::Participant>(
-      compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-      gpu_info->gpu_id, col_ctx_->input, col_ctx_->output,
-      col_params_->default_rank, std::move(done));
+      compute_stream->parent(), compute_stream, gpu_info, col_ctx_->input,
+      col_ctx_->output, col_params_->default_rank, std::move(done));
   VLOG(1) << "NcclGatherer calling NcclManager::AddToAllGather num_tasks "
           << col_params_->group.num_tasks << " current task "
           << col_params_->instance.task_names[col_params_->default_rank]
diff --git a/tensorflow/core/kernels/collective_nccl_reducer.cc b/tensorflow/core/kernels/collective_nccl_reducer.cc
index 873e4e3aa6c..399c537ad33 100644
--- a/tensorflow/core/kernels/collective_nccl_reducer.cc
+++ b/tensorflow/core/kernels/collective_nccl_reducer.cc
@@ -118,9 +118,8 @@ void NcclReducer::Run(StatusCallback done) {
     nccl_done.Notify();
   };
   auto participant = absl::make_unique<NcclManager::Participant>(
-      compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-      gpu_info->gpu_id, col_ctx_->input, col_ctx_->output,
-      col_params_->default_rank, std::move(done_callback));
+      compute_stream->parent(), compute_stream, gpu_info, col_ctx_->input,
+      col_ctx_->output, col_params_->default_rank, std::move(done_callback));
   VLOG(1) << "NcclReducer calling NcclManager::AddToAllReduce num_tasks "
           << col_params_->group.num_tasks << " current task "
           << col_params_->instance.task_names[col_params_->default_rank]
diff --git a/tensorflow/core/kernels/nccl_ops.cc b/tensorflow/core/kernels/nccl_ops.cc
index 9ccf591058e..bc028e8197c 100644
--- a/tensorflow/core/kernels/nccl_ops.cc
+++ b/tensorflow/core/kernels/nccl_ops.cc
@@ -108,9 +108,8 @@ class NcclAllReduceOpKernel : public NcclReduceOpBase {
     auto* compute_stream = c->op_device_context()->stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     auto participant = absl::make_unique<NcclManager::Participant>(
-        compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-        gpu_info->gpu_id, input, output, /*global_rank=*/-1,
-        std::move(actual_done));
+        compute_stream->parent(), compute_stream, gpu_info, input, output,
+        /*global_rank=*/-1, std::move(actual_done));
     NcclManager::instance()->AddToAllReduce(
         std::move(participant),
         {GetCollectiveKey(c),
@@ -140,9 +139,8 @@ class NcclReduceSendKernel : public NcclReduceOpBase {
     auto* compute_stream = c->op_device_context()->stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     auto participant = absl::make_unique<NcclManager::Participant>(
-        compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-        gpu_info->gpu_id, &c->input(0), /*output=*/nullptr, /*global_rank=*/-1,
-        std::move(actual_done));
+        compute_stream->parent(), compute_stream, gpu_info, &c->input(0),
+        /*output=*/nullptr, /*global_rank=*/-1, std::move(actual_done));
     NcclManager::instance()->AddReduceSend(
         std::move(participant),
         {GetCollectiveKey(c),
@@ -177,9 +175,8 @@ class NcclReduceRecvKernel : public NcclReduceOpBase {
     auto* compute_stream = c->op_device_context()->stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     auto participant = absl::make_unique<NcclManager::Participant>(
-        compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-        gpu_info->gpu_id, input, output, /*global_rank=*/-1,
-        std::move(actual_done));
+        compute_stream->parent(), compute_stream, gpu_info, input, output,
+        /*global_rank=*/-1, std::move(actual_done));
     NcclManager::instance()->AddReduceRecv(
         std::move(participant),
         {GetCollectiveKey(c),
@@ -212,9 +209,8 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase {
     auto* compute_stream = c->op_device_context()->stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     auto participant = absl::make_unique<NcclManager::Participant>(
-        compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-        gpu_info->gpu_id, &c->input(0), /*output=*/nullptr, /*global_rank=*/-1,
-        std::move(actual_done));
+        compute_stream->parent(), compute_stream, gpu_info, &c->input(0),
+        /*output=*/nullptr, /*global_rank=*/-1, std::move(actual_done));
     NcclManager::instance()->AddBroadcastSend(
         std::move(participant), {GetCollectiveKey(c),
                                  /*num_local_devices=*/num_devices(),
@@ -249,9 +245,8 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
     auto* compute_stream = c->op_device_context()->stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     auto participant = absl::make_unique<NcclManager::Participant>(
-        compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-        gpu_info->gpu_id, /*input=*/nullptr, output, /*global_rank=*/-1,
-        std::move(actual_done));
+        compute_stream->parent(), compute_stream, gpu_info,
+        /*input=*/nullptr, output, /*global_rank=*/-1, std::move(actual_done));
     NcclManager::instance()->AddBroadcastRecv(
         std::move(participant), {GetCollectiveKey(c),
                                  /*num_local_devices=*/num_devices(),
diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD
index 24ea6416084..1931d0fa47c 100644
--- a/tensorflow/core/nccl/BUILD
+++ b/tensorflow/core/nccl/BUILD
@@ -30,12 +30,13 @@ cc_library(
     ]),
     copts = tf_copts(),
     deps = if_cuda([
-        "@com_google_absl//absl/memory",
         "@local_config_nccl//:nccl",
     ]) + if_rocm([
         "@local_config_rocm//rocm:rccl",
+        "//tensorflow/core:gpu_runtime",
     ]) + if_cuda_or_rocm([
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/memory",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:gpu_headers_lib",
@@ -51,7 +52,6 @@ tf_cuda_cc_test(
     srcs = ["nccl_manager_test.cc"],
     tags = tf_cuda_tests_tags() + [
         "no_cuda_on_cpu_tap",  # TODO(b/120284216): re-enable multi_gpu
-        "no_rocm",
     ],
     deps = [
         "//tensorflow/core:test",
diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc
index 4a439a46525..2bdfbe34584 100644
--- a/tensorflow/core/nccl/nccl_manager.cc
+++ b/tensorflow/core/nccl/nccl_manager.cc
@@ -40,6 +40,7 @@ using se::rocm::ScopedActivateExecutorContext;
 #define cudaGetDevice hipGetDevice
 #define cudaSetDevice hipSetDevice
 #define cudaSuccess hipSuccess
+int NcclManager::instance_count = 0;
 #endif
 
 #define NCCL_RETURN_IF_ERROR(...) \
@@ -69,7 +70,12 @@ struct NcclManager::NcclStream : public core::RefCounted {
 
   // The stream on which to run the nccl collective.
   // This is a different stream than the tensorflow compute stream.
+#if TENSORFLOW_USE_ROCM
+  // On ROCm, we borrow the nccl stream from the device context.
+  se::Stream* stream = nullptr;
+#else
   std::unique_ptr<se::Stream> stream;
+#endif
 
   // `mu` protects access to `pending_launches_`, which is the list of
   // collectives ready but whose kernels are yet to be launched. When the
@@ -155,6 +161,16 @@ struct NcclManager::Collective : public core::RefCounted {
         single_node(num_local_devices_in == num_global_devices_in),
         communicator_key(communicator_key_in) {
     participants.reserve(num_local_devices_in);
+#if TENSORFLOW_USE_ROCM
+    // On ROCm platform, this allows caller to either use the singleton instance
+    // or to manage one non-singleton NcclManager instance.
+    // For example, the nccl_manager_test will use both paradigms in the same
+    // executable, but not running concurrently (which would hang otherwise).
+    if (NcclManager::instance_count > 1) {
+      status = errors::Internal(
+          "ROCm cannot use multi-node NCCL collectives on a single node");
+    }
+#endif
   }
 
   const string collective_key;  // A unique key for debugging.
@@ -193,9 +209,17 @@ struct NcclManager::Collective : public core::RefCounted {
   Status status;
 };
 
-NcclManager::NcclManager() { VLOG(2) << "New NcclManager " << this; }
+NcclManager::NcclManager() {
+  VLOG(2) << "New NcclManager " << this;
+#if TENSORFLOW_USE_ROCM
+  ++instance_count;
+#endif
+}
 NcclManager::~NcclManager() {
   VLOG(2) << "~NcclManager " << this;
+#if TENSORFLOW_USE_ROCM
+  --instance_count;
+#endif
   for (auto& it : device_to_comm_streams_) {
     for (NcclStream* nccl_stream : it.second) {
       {
@@ -209,6 +233,12 @@ NcclManager::~NcclManager()
 }
 
 NcclManager* NcclManager::instance() {
   static NcclManager* instance = new NcclManager();
+#if TENSORFLOW_USE_ROCM
+  // singleton does not count against total instances
+  // see comment above in Collective constructor concerning ROCm platform
+  static std::once_flag once;
+  std::call_once(once, [] { --NcclManager::instance_count; });
+#endif
   return instance;
 }
@@ -313,8 +343,12 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
     if (nccl_stream == nullptr) {
       nccl_stream = new NcclStream();
       nccl_stream->executor = executor;
+#if TENSORFLOW_USE_ROCM
+      nccl_stream->stream = collective->participants[i]->context->nccl_stream();
+#else
       nccl_stream->stream.reset(new se::Stream(executor));
       nccl_stream->stream->Init();
+#endif
 
       streams.emplace_back(nccl_stream);
       used_streams.insert(nccl_stream);
@@ -604,7 +638,11 @@ void NcclManager::RunCollective(Collective* collective) {
 }
 
 void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
+#if TENSORFLOW_USE_ROCM
+  se::Stream* comm_stream = nccl_stream->stream;
+#else
   se::Stream* comm_stream = nccl_stream->stream.get();
+#endif
   ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
   const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
       comm_stream->implementation()->GpuStreamMemberHack());
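The ROCm-only `instance_count` bookkeeping above is easier to see in isolation. The sketch below is a self-contained, hypothetical `Manager` class (not the TensorFlow implementation) showing the same pattern: every construction bumps a static counter, the lazily created singleton un-counts itself exactly once via `std::call_once`, and code that cannot tolerate multiple live managers at once (as the `Collective` constructor comment notes for RCCL) can simply check `instance_count > 1`.

```cpp
// Standalone illustration of the counting-plus-call_once pattern; the class
// name and layout are made up for the example.
#include <mutex>

class Manager {
 public:
  static int instance_count;

  Manager() { ++instance_count; }
  ~Manager() { --instance_count; }

  static Manager* instance() {
    static Manager* singleton = new Manager();
    // The process-wide singleton should not count against the total; undo its
    // increment the first time instance() is called.
    static std::once_flag once;
    std::call_once(once, [] { --instance_count; });
    return singleton;
  }
};

int Manager::instance_count = 0;
```

With this, the counter reflects only explicitly created extra instances, such as the per-"node" managers used by nccl_manager_test.cc, while normal use of the singleton keeps it at zero.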
diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h
index 9f4ef255ab3..b0b4441b776 100644
--- a/tensorflow/core/nccl/nccl_manager.h
+++ b/tensorflow/core/nccl/nccl_manager.h
@@ -32,8 +32,10 @@ limitations under the License.
 #include "third_party/nccl/nccl.h"
 #elif TENSORFLOW_USE_ROCM
 #include "rocm/include/rccl/rccl.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
 #endif
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/stream_executor.h"
@@ -53,6 +55,10 @@ class NcclManager {
 
   static NcclManager* instance();
 
+#if TENSORFLOW_USE_ROCM
+  static int instance_count;
+#endif
+
   // Calls `ncclGetUniqueId` and returns the id as a string. The returned value
   // may be shared with other participants on different nodes and passed in to
   // multi-node collective invocations.
@@ -61,12 +67,15 @@
   // A participant in a Collective.
   struct Participant {
     Participant(se::StreamExecutor* executor, se::Stream* tensor_stream,
-                EventMgr* event_mgr, int gpu_device_id, const Tensor* input,
+                const DeviceBase::GpuDeviceInfo* info, const Tensor* input,
                 Tensor* output, int global_rank, DoneCallback done_callback)
         : executor(executor),
           tensor_stream(tensor_stream),
-          event_mgr(event_mgr),
-          gpu_device_id(gpu_device_id),
+          event_mgr(info->event_mgr),
+          gpu_device_id(info->gpu_id),
+#if TENSORFLOW_USE_ROCM
+          context(static_cast<GPUDeviceContext*>(info->default_context)),
+#endif
           input(input),
           input_event(nullptr),
           output(output),
@@ -101,6 +110,10 @@ class NcclManager {
 
     const int gpu_device_id;
 
+#if TENSORFLOW_USE_ROCM
+    GPUDeviceContext* const context;
+#endif
+
     // Owned by the caller, who must keep it live until `done_callback` is
     // called. Is NULL for participants that only receive data.
     const Tensor* input;
diff --git a/tensorflow/core/nccl/nccl_manager_test.cc b/tensorflow/core/nccl/nccl_manager_test.cc
index 9b650c66fa7..8d4e48c9e33 100644
--- a/tensorflow/core/nccl/nccl_manager_test.cc
+++ b/tensorflow/core/nccl/nccl_manager_test.cc
@@ -303,13 +303,13 @@ class NcclManagerTest : public ::testing::Test {
 
       for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) {
         auto* device = this->GetDevice(local_rank);
-        auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+        auto* info = device->tensorflow_gpu_device_info();
         auto* stream = device->tensorflow_gpu_device_info()->stream;
         const int global_rank = node * num_ranks_per_node + local_rank;
         auto participant = absl::make_unique<NcclManager::Participant>(
-            device->executor(), stream, event_mgr, device->gpu_id(),
-            &test_case->ins[global_rank], &test_case->outs[global_rank],
-            global_rank, this->CreateDoneCallback(test_case.get()));
+            device->executor(), stream, info, &test_case->ins[global_rank],
+            &test_case->outs[global_rank], global_rank,
+            this->CreateDoneCallback(test_case.get()));
         node_states[node].nccl_manager.AddToAllReduce(
             std::move(participant),
             {collective_key, num_ranks_per_node, num_global_ranks,
@@ -351,7 +351,7 @@ class NcclManagerTest : public ::testing::Test {
                        src_global_rank, local_rank, &node_states, &collective_key,
                        &communicator_key, &test_case]() {
         auto* device = this->GetDevice(local_rank);
-        auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+        auto* info = device->tensorflow_gpu_device_info();
         auto* stream = device->tensorflow_gpu_device_info()->stream;
         const int global_rank = node * num_ranks_per_node + local_rank;
         auto* input = global_rank == src_global_rank
@@ -361,18 +361,14 @@ class NcclManagerTest : public ::testing::Test {
                           ? nullptr
                           : &test_case->outs[global_rank];
         auto participant = absl::make_unique<NcclManager::Participant>(
-            device->executor(), stream, event_mgr, device->gpu_id(), input,
-            output, global_rank, this->CreateDoneCallback(test_case.get()));
+            device->executor(), stream, info, input, output, global_rank,
+            this->CreateDoneCallback(test_case.get()));
         if (global_rank == src_global_rank) {
-          VLOG(1) << "AddBroadcastSend node " << node << " global_rank "
-                  << global_rank;
           node_states[node].nccl_manager.AddBroadcastSend(
               std::move(participant),
               {collective_key, num_ranks_per_node, num_global_ranks,
                communicator_key, src_global_rank});
         } else {
-          VLOG(1) << "AddBroadcastRecv node " << node << " global_rank "
-                  << global_rank;
           node_states[node].nccl_manager.AddBroadcastRecv(
               std::move(participant),
               {collective_key, num_ranks_per_node, num_global_ranks,
@@ -442,11 +438,11 @@ TYPED_TEST(NcclManagerTest, BasicSumReduction) {
     for (int rank = 0; rank < num_ranks; ++rank) {
       auto* device = this->GetDevice(rank);
       VLOG(2) << "rank " << rank << " device " << device->name();
-      auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+      auto* info = device->tensorflow_gpu_device_info();
       auto* stream = device->tensorflow_gpu_device_info()->stream;
       auto participant = absl::make_unique<NcclManager::Participant>(
-          device->executor(), stream, event_mgr, device->gpu_id(),
-          &test_case->ins[rank], &test_case->outs[rank], /*global_rank=*/-1,
+          device->executor(), stream, info, &test_case->ins[rank],
+          &test_case->outs[rank], /*global_rank=*/-1,
           this->CreateDoneCallback(test_case.get()));
       NcclManager::instance()->AddToAllReduce(
           std::move(participant),
@@ -508,12 +504,12 @@ TYPED_TEST(NcclManagerTest, MultipleCallers) {
         case_and_rank.pop_back();
       }
       auto* device = this->GetDevice(rank);
-      auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+      auto* info = device->tensorflow_gpu_device_info();
       auto* stream = device->tensorflow_gpu_device_info()->stream;
       typename TestFixture::TestCase* test_case = test_cases[test_num].get();
       auto participant = absl::make_unique<NcclManager::Participant>(
-          device->executor(), stream, event_mgr, device->gpu_id(),
-          &test_case->ins[rank], &test_case->outs[rank], /*global_rank=*/-1,
+          device->executor(), stream, info, &test_case->ins[rank],
+          &test_case->outs[rank], /*global_rank=*/-1,
           this->CreateDoneCallback(test_case));
       NcclManager::instance()->AddToAllReduce(
           std::move(participant),
@@ -551,11 +547,11 @@ TYPED_TEST(NcclManagerTest, BasicAllGather) {
   for (int rank = 0; rank < num_ranks; ++rank) {
     auto* device = this->GetDevice(rank);
     VLOG(2) << "rank " << rank << " device " << device->name();
-    auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+    auto* info = device->tensorflow_gpu_device_info();
     auto* stream = device->tensorflow_gpu_device_info()->stream;
     auto participant = absl::make_unique<NcclManager::Participant>(
-        device->executor(), stream, event_mgr, device->gpu_id(),
-        &test_case->ins[rank], &test_case->outs[rank], rank,
+        device->executor(), stream, info, &test_case->ins[rank],
+        &test_case->outs[rank], rank,
         this->CreateDoneCallback(test_case.get()));
     NcclManager::instance()->AddToAllGather(
         std::move(participant),
@@ -585,7 +581,12 @@ TYPED_TEST(NcclManagerTest, InPlaceBroadcast) {
 
 // Test broadcast with increasing ranks.
 TYPED_TEST(NcclManagerTest, BroadcastWithDifferentRanks) {
-  for (int num_ranks = 4; num_ranks <= 8; ++num_ranks) {
+#if TENSORFLOW_USE_ROCM
+  for (int num_ranks = 1; num_ranks <= 4; ++num_ranks)
+#else
+  for (int num_ranks = 4; num_ranks <= 8; ++num_ranks)
+#endif
+  {
     const int src_rank = static_cast<int>(random::New64() % num_ranks);
     for (int in_place_idx = 0; in_place_idx <= 1; ++in_place_idx) {
       const bool in_place = in_place_idx == 0;
@@ -603,12 +604,14 @@ TEST(NcclManagerTest, CommunicatorKey) {
   EXPECT_EQ(communicator_key.size(), NCCL_UNIQUE_ID_BYTES);
 }
 
+#if !TENSORFLOW_USE_ROCM
 // This test creates `num_nodes` NcclManagers to simulate a multi-node
 // environment. It works on a single node and reuses GPUs. It enqueues NCCL
 // kernels on separate stream per rank.
 TYPED_TEST(NcclManagerTest, MultiNode) {
   this->RunMultiNodeAllReduceTest(/*num_nodes=*/2, /*num_ranks_per_node=*/4);
 }
+#endif
 
 // Tests that specifying `communicator_key` with a single node NCCL collective
 // works well.
@@ -618,9 +621,15 @@ TYPED_TEST(NcclManagerTest, MultiNodeSingle) {
 
 // Multi-node broadcast.
 TYPED_TEST(NcclManagerTest, MultiNodeBroadcast) {
+#if TENSORFLOW_USE_ROCM
+  this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4,
+                                  /*src_node=*/0, /*src_local_rank=*/3,
+                                  /*in_place=*/true);
+#else
   this->RunMultiNodeBroadcastTest(/*num_nodes=*/4, /*num_ranks_per_node=*/8,
                                   /*src_node=*/2, /*src_local_rank=*/3,
                                   /*in_place=*/true);
+#endif
 }
 
 // Checks that we return error status if a collective_key is used for different
@@ -633,11 +642,11 @@ TYPED_TEST(NcclManagerTest, ConsistentCollectiveType) {
       TensorShape({2, 3}), 0.0f));
   for (int rank = 0; rank < num_ranks; ++rank) {
     auto* device = this->GetDevice(rank);
-    auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+    auto* info = device->tensorflow_gpu_device_info();
     auto* stream = device->tensorflow_gpu_device_info()->stream;
     auto participant = absl::make_unique<NcclManager::Participant>(
-        device->executor(), stream, event_mgr, device->gpu_id(),
-        &test_case->ins[rank], &test_case->outs[rank], /*global_rank=*/-1,
+        device->executor(), stream, info, &test_case->ins[rank],
+        &test_case->outs[rank], /*global_rank=*/-1,
         this->CreateDoneCallback(test_case.get()));
     if (rank == 0) {
       NcclManager::instance()->AddToAllReduce(std::move(participant),
@@ -670,11 +679,11 @@ TYPED_TEST(NcclManagerTest, ConsistentCommunicatorKey) {
       TensorShape({2, 3}), 0.0f));
   for (int rank = 0; rank < num_ranks; ++rank) {
     auto* device = this->GetDevice(rank);
-    auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+    auto* info = device->tensorflow_gpu_device_info();
     auto* stream = device->tensorflow_gpu_device_info()->stream;
     auto participant = absl::make_unique<NcclManager::Participant>(
-        device->executor(), stream, event_mgr, device->gpu_id(),
-        &test_case->ins[rank], &test_case->outs[rank], /*global_rank=*/-1,
+        device->executor(), stream, info, &test_case->ins[rank],
+        &test_case->outs[rank], /*global_rank=*/-1,
         this->CreateDoneCallback(test_case.get()));
     NcclManager::instance()->AddToAllReduce(
         std::move(participant),
@@ -699,12 +708,12 @@ TYPED_TEST(NcclManagerTest, ConsistentNumberOfDevices) {
       TensorShape({2, 3}), 0.0f));
   for (int rank = 0; rank < num_ranks; ++rank) {
     auto* device = this->GetDevice(rank);
-    auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+    auto* info = device->tensorflow_gpu_device_info();
     auto* stream = device->tensorflow_gpu_device_info()->stream;
     int num_devices = rank == 0 ? num_ranks : num_ranks + 1;
     auto participant = absl::make_unique<NcclManager::Participant>(
-        device->executor(), stream, event_mgr, device->gpu_id(),
-        &test_case->ins[rank], &test_case->outs[rank], /*global_rank=*/-1,
+        device->executor(), stream, info, &test_case->ins[rank],
+        &test_case->outs[rank], /*global_rank=*/-1,
         this->CreateDoneCallback(test_case.get()));
     NcclManager::instance()->AddToAllReduce(std::move(participant),
                                             {"bad_coll_type",
@@ -728,11 +737,10 @@ TYPED_TEST(NcclManagerTest, BroadcastNoSource) {
                                        /*src_rank=*/-1, false));
   for (int rank = 0; rank < num_ranks; ++rank) {
     auto* device = this->GetDevice(rank);
-    auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+    auto* info = device->tensorflow_gpu_device_info();
     auto* stream = device->tensorflow_gpu_device_info()->stream;
     auto participant = absl::make_unique<NcclManager::Participant>(
-        device->executor(), stream, event_mgr, device->gpu_id(), nullptr,
-        &test_case->outs[rank], rank,
+        device->executor(), stream, info, nullptr, &test_case->outs[rank], rank,
         this->CreateDoneCallback(test_case.get()));
     NcclManager::instance()->AddBroadcastRecv(std::move(participant),
                                               {"bcast_no_send",
@@ -755,11 +763,11 @@ TYPED_TEST(NcclManagerTest, BroadcastMultipleSends) {
                                        /*src_rank=*/-1, false));
   for (int rank = 0; rank < num_ranks; ++rank) {
     auto* device = this->GetDevice(rank);
-    auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+    auto* info = device->tensorflow_gpu_device_info();
     auto* stream = device->tensorflow_gpu_device_info()->stream;
     auto participant = absl::make_unique<NcclManager::Participant>(
-        device->executor(), stream, event_mgr, device->gpu_id(),
-        &test_case->outs[rank], &test_case->outs[rank], rank,
+        device->executor(), stream, info, &test_case->outs[rank],
+        &test_case->outs[rank], rank,
         this->CreateDoneCallback(test_case.get()));
     NcclManager::instance()->AddBroadcastSend(std::move(participant),
                                               {"bcast_multiple_send",
@@ -783,11 +791,11 @@ TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) {
                                        /*src_rank=*/-1, false));
   for (int rank = 0; rank < num_ranks; ++rank) {
     auto* device = this->GetDevice(rank);
-    auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
+    auto* info = device->tensorflow_gpu_device_info();
     auto* stream = device->tensorflow_gpu_device_info()->stream;
     auto participant = absl::make_unique<NcclManager::Participant>(
-        device->executor(), stream, event_mgr, device->gpu_id(),
-        &test_case->outs[rank], &test_case->outs[rank], rank,
+        device->executor(), stream, info, &test_case->outs[rank],
+        &test_case->outs[rank], rank,
         this->CreateDoneCallback(test_case.get()));
     NcclManager::instance()->AddBroadcastRecv(std::move(participant),
                                               {"bcast_inconsistent_source",
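Taken together, the header and call-site changes mean a caller only needs the device's `GpuDeviceInfo` to build a `NcclManager::Participant`. The sketch below is a hypothetical free function that consolidates the pattern used by the kernels and tests above; it is not part of the patch.

```cpp
// Illustrative only: mirrors the updated call sites. The caller hands over the
// whole GpuDeviceInfo; the Participant constructor unpacks event_mgr and
// gpu_id from it, and (on ROCm) also captures info->default_context so that
// NcclManager::GetCommunicator() can borrow that context's nccl stream.
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/nccl/nccl_manager.h"

namespace tensorflow {

std::unique_ptr<NcclManager::Participant> MakeParticipant(
    OpKernelContext* c, const Tensor* input, Tensor* output,
    NcclManager::DoneCallback done) {
  se::Stream* compute_stream = c->op_device_context()->stream();
  const DeviceBase::GpuDeviceInfo* gpu_info =
      c->device()->tensorflow_gpu_device_info();
  return absl::make_unique<NcclManager::Participant>(
      compute_stream->parent(), compute_stream, gpu_info, input, output,
      /*global_rank=*/-1, std::move(done));
}

}  // namespace tensorflow
```

The returned participant would then be passed to one of the `NcclManager::instance()->Add*` calls together with a collective context, exactly as in the AddToAllReduce/AddBroadcastSend call sites shown in the diff.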