Add a collective broadcast implementation using NCCL.

The implementation adds `NcclBroadcaster`, a subclass of `NcclBase` analogous to `NcclReducer`. This change also refactors the collective NCCL tests into a shared `NcclTestBase` fixture so the reducer and broadcaster tests share setup and verification code.

PiperOrigin-RevId: 261021538
commit 9bdc9dbf52 (parent 4bdc9d0ee6)
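For orientation, the core of the change is the dispatch in the param resolver: a broadcast collective now resolves to the NCCL implementation when NCCL is enabled. A condensed sketch of the post-change selection logic (the broadcast branch is taken from the first hunk below; the `RingReduce` fallback name and the `default` case are assumptions for illustration, not part of this diff):

// Condensed sketch, not the literal file contents; see the first hunk below.
const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
  switch (cp->instance.type) {
    case BROADCAST_COLLECTIVE:
      // New in this change: prefer the NCCL broadcast when NCCL is enabled.
      return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast";
    case REDUCTION_COLLECTIVE:
      return nccl ? "NcclReduce" : "RingReduce";  // fallback name assumed
    default:
      return "undef";  // assumed default; remaining cases elided
  }
}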
tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -56,10 +56,14 @@ void CollectiveParamResolverLocal::CompleteGroupAsync(
 }
 
 namespace {
-string GetCollectiveName(const CollectiveParams* cp, bool nccl) {
+const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
   switch (cp->instance.type) {
     case BROADCAST_COLLECTIVE:
-      return "HierarchicalTreeBroadcast";
+      if (nccl) {
+        return "NcclBroadcast";
+      } else {
+        return "HierarchicalTreeBroadcast";
+      }
 
     case REDUCTION_COLLECTIVE: {
       if (nccl) {
@@ -96,8 +100,8 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
 
   // Initialize group runtime details.
   CollectiveImplementationInterface* col_impl;
-  // TODO(b/128853131,b/132707282): Remove NCCL special case when we have
-  // NCCL implementations for all collectives.
+  // TODO(b/128853131): Remove NCCL special case when we have NCCL
+  // implementations for all collectives.
   status = CollectiveRegistry::LookupParamResolverInstance(
       nccl_ ? "NcclReduce" : GetCollectiveName(cp, /*nccl=*/false),
       &col_impl);
tensorflow/core/kernels/BUILD
@@ -200,6 +200,8 @@ tf_kernel_library(
     srcs = if_nccl([
         "collective_nccl.h",
         "collective_nccl.cc",
+        "collective_nccl_broadcaster.h",
+        "collective_nccl_broadcaster.cc",
         "collective_nccl_reducer.h",
         "collective_nccl_reducer.cc",
     ]),
@@ -216,9 +218,9 @@ tf_kernel_library(
 )
 
 tf_cuda_cc_test(
-    name = "collective_nccl_reducer_test",
+    name = "collective_nccl_test",
     size = "small",
-    srcs = ["collective_nccl_reducer_test.cc"],
+    srcs = ["collective_nccl_test.cc"],
     tags = tf_cuda_tests_tags() + ["no_cuda_on_cpu_tap"],
     deps = [
         "//tensorflow/core:all_kernels",
tensorflow/core/kernels/collective_nccl_broadcaster.cc (new file, 81 lines)
@@ -0,0 +1,81 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/collective_nccl_broadcaster.h"

#ifdef GOOGLE_CUDA

#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

void NcclBroadcaster::Run(StatusCallback done) {
  auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
  auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
  const int num_global_devices = col_params_->group.group_size;
  const int num_local_devices = col_params_->instance.num_devices_per_task.at(
      col_params_->instance.task_names[col_params_->default_rank]);
  string nccl_collective_key =
      NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
  auto participant = absl::make_unique<NcclManager::Participant>(
      compute_stream->parent(), compute_stream, gpu_info->event_mgr,
      gpu_info->gpu_id, col_ctx_->input, col_ctx_->output,
      col_params_->default_rank, std::move(done));
  VLOG(1)
      << "NcclBroadcast calling NcclManager::AddBroadcastSend/Recv num_tasks "
      << col_params_->group.num_tasks << " current task "
      << col_params_->instance.task_names[col_params_->default_rank]
      << " num local devices " << num_local_devices << " num global devices "
      << num_global_devices << " rank " << col_params_->default_rank
      << " device " << col_ctx_->device_name << " instance "
      << col_params_->instance.instance_key << " source "
      << col_params_->is_source;
  if (col_params_->is_source) {
    NcclManager::instance()->AddBroadcastSend(
        std::move(participant),
        {std::move(nccl_collective_key), num_local_devices, num_global_devices,
         col_params_->group.runtime_details.communicator_key});
  } else {
    NcclManager::instance()->AddBroadcastRecv(
        std::move(participant),
        {std::move(nccl_collective_key), num_local_devices, num_global_devices,
         col_params_->group.runtime_details.communicator_key});
  }
  {
    // `WaitForDependencies` may block if the collective instances on which this
    // op depends have not yet launched. When this function returns, this op is
    // ready to go.
    profiler::TraceMe activity("WaitForDependencies",
                               profiler::TraceMeLevel::kInfo);
    col_ctx_->col_exec->WaitForDependencies(*col_params_);
    NcclManager::instance()->SignalMultiNodeReady(nccl_collective_key);
  }
  {
    // When all devices at this worker have called `SignalMultiNodeReady`, the
    // `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
    // implementation of `Launched` keeps track of the number of devices that
    // have launched.
    profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
    col_ctx_->col_exec->Launched(*col_params_);
  }
}

REGISTER_COLLECTIVE(NcclBroadcast, NcclBroadcaster);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
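`REGISTER_COLLECTIVE(NcclBroadcast, NcclBroadcaster)` ties the class to the `"NcclBroadcast"` name that `GetCollectiveName` now returns, so the param resolver can find it by name. A minimal sketch of that resolution step, using the same `CollectiveRegistry` call that appears in the `CompleteGroupLocal` hunk above (surrounding setup and error handling elided):

// Sketch: resolving the registered collective name to an implementation.
CollectiveImplementationInterface* col_impl = nullptr;
Status status = CollectiveRegistry::LookupParamResolverInstance(
    "NcclBroadcast", &col_impl);
if (status.ok()) {
  // At execution time, the instance's Run(done) hands the broadcast off to
  // NcclManager, as in NcclBroadcaster::Run above.
}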
tensorflow/core/kernels/collective_nccl_broadcaster.h (new file, 35 lines)
@@ -0,0 +1,35 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_BROADCASTER_H_
#define TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_BROADCASTER_H_

#include "tensorflow/core/kernels/collective_nccl.h"

namespace tensorflow {
#ifdef GOOGLE_CUDA

class NcclBroadcaster : public NcclBase {
 public:
  NcclBroadcaster() : NcclBase(BROADCAST_COLLECTIVE, "NcclBroadcast") {}
  ~NcclBroadcaster() override = default;

  // Hands off broadcast to NcclManager.
  void Run(StatusCallback done) override;
};

#endif  // GOOGLE_CUDA
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_BROADCASTER_H_
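The refactored test below exercises this class directly. Condensed from the `RunBroadcast` hunk further down (the `OpKernelContext` and `CollectiveParams` setup are elided; note that only the source rank supplies an input tensor):

// Condensed from RunBroadcast in the test diff below.
NcclBroadcaster broadcaster;
CollectiveContext col_ctx(col_exec, dev_mgr, /*OpKernelContext=*/&ctx,
                          &op_params, col_params, exec_key, kStepId,
                          /*input=*/col_params.is_source ? &tensor : nullptr,
                          /*output=*/&tensor);
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
Notification note;
broadcaster.Run([&note](Status s) {
  TF_CHECK_OK(s);  // the test stores s and checks it with TF_ASSERT_OK
  note.Notify();
});
note.WaitForNotification();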
tensorflow/core/kernels/{collective_nccl_reducer_test.cc → collective_nccl_test.cc} (renamed)
@@ -15,7 +15,7 @@ limitations under the License.
 
 #ifdef GOOGLE_CUDA
 
-#include "tensorflow/core/kernels/collective_nccl_reducer.h"
+#include "tensorflow/core/kernels/collective_nccl.h"
 
 #include <algorithm>
 
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
 #include "tensorflow/core/framework/collective.h"
@@ -32,6 +33,8 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/collective_nccl_broadcaster.h"
+#include "tensorflow/core/kernels/collective_nccl_reducer.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
@@ -70,9 +73,13 @@ std::unique_ptr<OpKernel> GetDiv(DeviceBase* device) {
   return GetKernel(node_def, device);
 }
 
-class NcclReducerTest : public ::testing::Test {
+class NcclTestBase : public ::testing::Test {
  protected:
-  ~NcclReducerTest() override {
+  class DeviceInstance;
+
+  NcclTestBase(CollectiveType collective_type, const string& collective_name)
+      : collective_type_(collective_type), collective_name_(collective_name) {}
+  ~NcclTestBase() override {
     if (col_exec_) col_exec_->Unref();
   }
 
@@ -92,7 +99,7 @@ class NcclReducerTest : public ::testing::Test {
     }
   }
 
-  void Init(int num_ranks) {
+  void Init(const int num_ranks, const int instance_key) {
     setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
     setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
     InitGPUDevices();
@@ -115,15 +122,14 @@ class NcclReducerTest : public ::testing::Test {
 
     // Initialize collective params.
     col_params_.name = "test_nccl_collective_op";
-    const int group_key = 5;
+    const int group_key = num_ranks;
     col_params_.group.group_key = group_key;
     col_params_.group.device_type = DEVICE_GPU;
     col_params_.group.group_size = num_ranks;
-    const int instance_key = 23;
     col_params_.instance.instance_key = instance_key;
-    col_params_.instance.type = REDUCTION_COLLECTIVE;
+    col_params_.instance.type = collective_type_;
     col_params_.instance.data_type = DT_FLOAT;
-    col_params_.instance.impl_details.collective_name = "NcclReduce";
+    col_params_.instance.impl_details.collective_name = collective_name_;
     const string task_name = "/job:worker/replica:0/task:0";
     col_params_.instance.num_devices_per_task[task_name] = num_ranks;
     for (int rank = 0; rank < num_ranks; ++rank) {
@@ -137,14 +143,28 @@ class NcclReducerTest : public ::testing::Test {
     }
   }
 
-  void Reduce() {
+  // Initialize `input` tensor at rank `rank`.
+  virtual void InitInput(Tensor* input, const int rank) = 0;
+
+  // Initialize `expected` output at all `num_ranks` ranks.
+  virtual void InitExpected(std::vector<float>* expected,
+                            const int tensor_length, const int num_ranks) = 0;
+
+  // Initialize device `di` specific to the collective op.
+  virtual void InitDevice(DeviceInstance* di) = 0;
+
+  // Run collective op on device `di`.
+  virtual void RunCollectiveOnDevice(DeviceInstance* di) = 0;
+
+  void RunCollective() {
     int done = 0;
     mutex done_mu;
     condition_variable done_cv;
     for (const auto& instance : instances_) {
       DeviceInstance* di = instance.get();
-      SchedClosure([di, &done, &done_mu, &done_cv] {
-        di->DoReduce();
+      InitDevice(di);
+      SchedClosure([this, di, &done, &done_mu, &done_cv] {
+        RunCollectiveOnDevice(di);
         mutex_lock l(done_mu);
         ++done;
         done_cv.notify_all();
@@ -155,35 +175,32 @@ class NcclReducerTest : public ::testing::Test {
     while (done < instances_.size()) done_cv.wait(l);
   }
 
-  void RunTest(int num_ranks, int tensor_length) {
-    Init(num_ranks);
+  void RunTest(int num_ranks, int tensor_length, int instance_key) {
+    Init(num_ranks, instance_key);
     std::vector<float> expected(tensor_length, 0.0);
+    InitExpected(&expected, tensor_length, num_ranks);
     for (int rank = 0; rank < num_ranks; ++rank) {
       DeviceInstance* instance = instances_[rank].get();
       instance->InitTensor(DT_FLOAT, TensorShape({tensor_length}),
-                           [&expected, rank](Tensor* t) {
-                             for (size_t i = 0; i < t->NumElements(); ++i) {
-                               float value = pow(10, rank) * i;
-                               t->flat<float>()(i) = value;
-                               expected[i] += value;
-                             }
-                           });
+                           [this, rank](Tensor* t) { InitInput(t, rank); });
     }
-    Reduce();
+    RunCollective();
     // Confirm that every rank computed the same correct value.
-    for (int i = 0; i < tensor_length; ++i) {
-      expected[i] /= num_ranks;
-    }
     for (int rank = 0; rank < instances_.size(); ++rank) {
       TF_ASSERT_OK(instances_[rank]->status_);
       Tensor* dev_tensor = &instances_[rank]->tensor_;
+      VLOG(2) << "rank " << rank << " output " << dev_tensor << " buf "
+              << DMAHelper::base(dev_tensor);
       Tensor actual(DT_FLOAT, TensorShape({tensor_length}));
       Notification note;
       Device* dev = instances_[rank]->device_;
       auto* dev_info = dev->tensorflow_gpu_device_info();
       dev_info->default_context->CopyDeviceTensorToCPU(
           dev_tensor, /*tensor_name=*/"", dev, &actual,
-          [&note](const Status&) { note.Notify(); });
+          [&note](const Status& s) {
+            TF_CHECK_OK(s);
+            note.Notify();
+          });
       note.WaitForNotification();
       for (int i = 0; i < tensor_length; ++i) {
         EXPECT_FLOAT_EQ(expected[i], actual.template flat<float>()(i))
@@ -192,14 +209,12 @@ class NcclReducerTest : public ::testing::Test {
     }
   }
 
-  std::unique_ptr<OpKernel> GetCollectiveReduce(const CollectiveParams& params,
-                                                Tensor* input,
-                                                DeviceBase* device) {
+  std::unique_ptr<OpKernel> GetCollectiveReduceOpKernel(
+      const CollectiveParams& params, Tensor* input, DeviceBase* device) {
     mutex_lock l(mu_);
     NodeDef node_def;
-    NodeDefBuilder builder(
-        strings::StrCat("collective_reduce_", reduce_counter_++),
-        "CollectiveReduce");
+    NodeDefBuilder builder(strings::StrCat("collective_reduce_", op_counter_++),
+                           "CollectiveReduce");
     TF_CHECK_OK(
         builder.Attr("T", params.instance.data_type)
             .Attr("merge_op", "Add")
@@ -215,7 +230,7 @@ class NcclReducerTest : public ::testing::Test {
 
   class DeviceInstance {
    public:
-    DeviceInstance(int rank, const string& device_name, NcclReducerTest* parent)
+    DeviceInstance(int rank, const string& device_name, NcclTestBase* parent)
         : parent_(parent), device_name_(device_name), rank_(rank) {
       TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(device_name_, &device_))
           << "Could not find device " << device_name_ << " existing devices "
@@ -238,26 +253,16 @@ class NcclReducerTest : public ::testing::Test {
       auto* dev_info = device_->tensorflow_gpu_device_info();
       Notification note;
       dev_info->default_context->CopyCPUTensorToDevice(
-          &cpu_tensor, device_, &tensor_,
-          [&note](const Status&) { note.Notify(); });
+          &cpu_tensor, device_, &tensor_, [&note](const Status& s) {
+            TF_CHECK_OK(s);
+            note.Notify();
+          });
       note.WaitForNotification();
     }
 
-    void DoReduce() {
-      col_params_.merge_op = GetAdd(device_);
-      col_params_.final_op = GetDiv(device_);
-
-      // Prepare an OpKernelContext.
-      OpKernelContext::Params op_params;
-      op_params.step_id = kStepId;
-      op_params.device = device_;
-      gtl::InlinedVector<TensorValue, 4> inputs;
-      inputs.push_back(TensorValue(&tensor_));
-      op_params.inputs = &inputs;
-      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
-          {AllocatorAttributes()});
-      op_params.input_alloc_attrs = &input_aa;
-      gtl::InlinedVector<DeviceContext*, 4> input_dc;
+    void PrepareDeviceContext(OpKernelContext::Params* params) {
+      params->step_id = kStepId;
+      params->device = device_;
       DeviceContext* dev_ctx = nullptr;
       auto* dev_info = device_->tensorflow_gpu_device_info();
       if (dev_info) {
@@ -266,18 +271,32 @@ class NcclReducerTest : public ::testing::Test {
       } else {
         dev_ctx = new DeviceContext;
       }
-      input_dc.push_back(dev_ctx);
+      params->op_device_context = dev_ctx;
+    }
+
+    void RunReduce() {
+      // Prepare an OpKernelContext.
+      OpKernelContext::Params op_params;
+      PrepareDeviceContext(&op_params);
+
+      // Prepare inputs and outputs to OpKernel.
+      gtl::InlinedVector<TensorValue, 4> inputs;
+      inputs.push_back(TensorValue(&tensor_));
+      op_params.inputs = &inputs;
+      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
+          {AllocatorAttributes()});
+      op_params.input_alloc_attrs = &input_aa;
+      gtl::InlinedVector<DeviceContext*, 4> input_dc;
+      input_dc.push_back(op_params.op_device_context);
       op_params.input_device_contexts = &input_dc;
-      op_params.op_device_context = dev_ctx;
       int forward_from = 0;
       op_params.forward_from_array = &forward_from;
       AllocatorAttributes generic_alloc_attr;
       op_params.output_attr_array = &generic_alloc_attr;
       std::unique_ptr<OpKernel> op =
-          parent_->GetCollectiveReduce(col_params_, &tensor_, device_);
+          parent_->GetCollectiveReduceOpKernel(col_params_, &tensor_, device_);
       op_params.op_kernel = op.get();
       OpKernelContext ctx(&op_params, 1);
 
       // We never actually execute the kernel, so we need to do the output
       // allocation it would do, ourselves.
       Tensor* output_tensor_ptr = nullptr;
@@ -285,25 +304,57 @@ class NcclReducerTest : public ::testing::Test {
                                      &output_tensor_ptr));
       CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
 
-      // Prepare a NcclReducer instance.
+      // Run the all-reduce.
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
       NcclReducer reducer;
       CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                                &ctx, &op_params, col_params_, exec_key,
-                                kStepId, &tensor_, &tensor_);
+                                /*OpKernelContext=*/&ctx, &op_params,
+                                col_params_, exec_key, kStepId,
+                                /*input=*/&tensor_, /*output=*/&tensor_);
       TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
-
-      // Run the all-reduce.
-      reducer.Run([this](Status s) { status_ = s; });
+      Notification note;
+      reducer.Run([this, &note](Status s) {
+        status_ = s;
+        note.Notify();
+      });
+      note.WaitForNotification();
       if (status_.ok()) {
         CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
       }
 
-      dev_ctx->Unref();
+      op_params.op_device_context->Unref();
     }
 
-    NcclReducerTest* parent_;
+    void RunBroadcast() {
+      VLOG(2) << "RunBroadcast name " << parent_->collective_name_ << " rank "
+              << col_params_.default_rank;
+      // Prepare an OpKernelContext.
+      OpKernelContext::Params op_params;
+      PrepareDeviceContext(&op_params);
+      OpKernelContext ctx(&op_params, 1);
+
+      // Run broadcast.
+      string exec_key =
+          strings::StrCat(col_params_.instance.instance_key, ":0:0");
+      NcclBroadcaster broadcaster;
+      CollectiveContext col_ctx(
+          parent_->col_exec_, parent_->dev_mgr_.get(),
+          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
+          /*input=*/col_params_.is_source ? &tensor_ : nullptr,
+          /*output=*/&tensor_);
+      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
+      Notification note;
+      broadcaster.Run([this, &note](Status s) {
+        status_ = s;
+        note.Notify();
+      });
+      note.WaitForNotification();
+
+      op_params.op_device_context->Unref();
+    }
+
+    NcclTestBase* parent_;
     string device_name_;
     int rank_;
     Tensor tensor_;
@@ -312,6 +363,8 @@ class NcclReducerTest : public ::testing::Test {
     Status status_;
   };
 
+  CollectiveType collective_type_;
+  const string collective_name_;
   std::vector<std::unique_ptr<tensorflow::Device>> gpus_;
   TestCollectiveExecutorMgr col_exec_mgr_;
   CollectiveExecutor* col_exec_;
@@ -319,14 +372,110 @@ class NcclReducerTest : public ::testing::Test {
   std::vector<std::unique_ptr<DeviceInstance>> instances_;
   CollectiveParams col_params_;
   mutex mu_;
-  int32 reduce_counter_ GUARDED_BY(mu_) = 0;
+  int32 op_counter_ GUARDED_BY(mu_) = 0;
 };
 
-TEST_F(NcclReducerTest, Test2Dev16Len) { RunTest(2, 16); }
-TEST_F(NcclReducerTest, Test4Dev16Len) { RunTest(4, 16); }
-TEST_F(NcclReducerTest, Test8Dev16Len) { RunTest(8, 16); }
-TEST_F(NcclReducerTest, Test8Dev128Len) { RunTest(8, 128); }
-TEST_F(NcclReducerTest, Test8Dev1045991Len) { RunTest(8, 1048576); }
+class NcclReducerTest : public NcclTestBase {
+ protected:
+  NcclReducerTest()
+      : NcclTestBase(/*collective_type=*/REDUCTION_COLLECTIVE,
+                     /*collective_name=*/"NcclReduce") {}
+  ~NcclReducerTest() override = default;
+
+  void InitInput(Tensor* input, const int rank) override {
+    for (size_t i = 0; i < input->NumElements(); ++i) {
+      float value = pow(10, rank) * i;
+      input->flat<float>()(i) = value;
+    }
+  }
+
+  void InitExpected(std::vector<float>* expected, const int tensor_length,
+                    const int num_ranks) override {
+    expected->resize(tensor_length);
+    for (int i = 0; i < tensor_length; ++i) {
+      float expected_sum = 0.0;
+      for (int rank = 0; rank < num_ranks; ++rank) {
+        float value = pow(10, rank) * i;
+        expected_sum += value;
+      }
+      (*expected)[i] = expected_sum / num_ranks;
+    }
+  }
+
+  void InitDevice(DeviceInstance* di) override {
+    di->col_params_.merge_op = GetAdd(di->device_);
+    di->col_params_.final_op = GetDiv(di->device_);
+  }
+
+  void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunReduce(); }
+};
+
+class NcclBroadcasterTest : public NcclTestBase {
+ protected:
+  NcclBroadcasterTest()
+      : NcclTestBase(/*collective_type=*/BROADCAST_COLLECTIVE,
+                     /*collective_name=*/"NcclBroadcast") {}
+  ~NcclBroadcasterTest() override = default;
+
+  void InitInput(Tensor* input, const int rank) override {
+    bool source = rank == source_rank_;
+    for (size_t i = 0; i < input->NumElements(); ++i) {
+      input->flat<float>()(i) = source ? static_cast<float>(i) : -1.0;
+    }
+  }
+
+  void InitExpected(std::vector<float>* expected, const int tensor_length,
+                    const int num_ranks) override {
+    for (int i = 0; i < tensor_length; ++i) {
+      (*expected)[i] = i;
+    }
+  }
+
+  void InitDevice(DeviceInstance* di) override {
+    di->col_params_.source_rank = source_rank_;
+    di->col_params_.is_source = di->col_params_.default_rank == source_rank_;
+  }
+
+  void RunCollectiveOnDevice(DeviceInstance* di) override {
+    di->RunBroadcast();
+  }
+
+  int source_rank_ = 0;
+};
+
+TEST_F(NcclReducerTest, Test2Dev16Len) {
+  RunTest(/*num_ranks=*/2, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclReducerTest, Test4Dev16Len) {
+  RunTest(/*num_ranks=*/4, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclReducerTest, Test8Dev16Len) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclReducerTest, Test8Dev128Len) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/128, /*instance_key=*/23);
+}
+TEST_F(NcclReducerTest, Test8Dev1045991Len) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/1048576, /*instance_key=*/23);
+}
+
+TEST_F(NcclBroadcasterTest, Test2Dev16LenSrc0) {
+  RunTest(/*num_ranks=*/2, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclBroadcasterTest, Test4Dev16LenSrc1) {
+  source_rank_ = 1;
+  RunTest(/*num_ranks=*/4, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclBroadcasterTest, Test8Dev16LenSrc7) {
+  source_rank_ = 7;
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/16, /*instance_key=*/23);
+}
+TEST_F(NcclBroadcasterTest, Test8Dev128LenSrc0) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/128, /*instance_key=*/24);
+}
+TEST_F(NcclBroadcasterTest, Test8Dev1045991LenSrc0) {
+  RunTest(/*num_ranks=*/8, /*tensor_length=*/1048576, /*instance_key=*/23);
+}
 
 }  // namespace tensorflow