Fix collective_nccl_test after breakage in cl/355646163.

PiperOrigin-RevId: 355690665
Change-Id: I5ad78388c4d0e2b0be7daad1d6a136c1a8444c3d
commit: 6dd7d00588
parent: a0bd36e7f4
Author: Ayush Dubey
Committed by: TensorFlower Gardener
Date: 2021-02-04 12:50:07 -08:00

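The change is mechanical but worth spelling out: the test previously held its CollectiveParams by value, and the diff below converts both NcclTestBase and DeviceInstance to an owning CollectiveParams*, allocated with new and released with Unref() rather than delete. That pattern implies CollectiveParams became reference-counted, presumably in cl/355646163. Here is a minimal, self-contained sketch of the ownership pattern the test now follows; RefCounted, Params, and Test are illustrative stand-ins, not TensorFlow's actual classes:

    #include <atomic>
    #include <string>

    // Intrusive reference count: the creator starts with one reference and
    // the object deletes itself when the last reference is dropped.
    class RefCounted {
     public:
      RefCounted() : ref_(1) {}
      void Ref() const { ref_.fetch_add(1, std::memory_order_relaxed); }
      void Unref() const {
        if (ref_.fetch_sub(1, std::memory_order_acq_rel) == 1) delete this;
      }

     protected:
      virtual ~RefCounted() = default;  // destruction goes through Unref()

     private:
      mutable std::atomic<int> ref_;
    };

    struct Params : RefCounted {  // stand-in for CollectiveParams
      std::string name;
    };

    class Test {  // mirrors NcclTestBase's ownership of col_params_
     public:
      Test() : params_(nullptr) {}
      ~Test() {
        if (params_) params_->Unref();  // release, never delete
      }
      void SetUp() {
        params_ = new Params();  // refcount == 1, owned by the test
        params_->name = "example";
      }

     private:
      Params* params_;  // raw owning pointer, as in the diff
    };

    int main() {
      Test t;
      t.SetUp();
      return 0;
    }  // ~Test drops the last reference, which deletes the Params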

@@ -88,10 +88,12 @@ class NcclTestBase : public ::testing::Test {
         nccl_communicator_(MaybeCreateNcclCommunicator()),
         work_queue_(std::make_shared<UnboundedWorkQueue>(
             Env::Default(), "collective_executor")),
-        col_exec_(nullptr) {}
+        col_exec_(nullptr),
+        col_params_(nullptr) {}
 
   ~NcclTestBase() override {
     if (col_exec_) col_exec_->Unref();
+    if (col_params_) col_params_->Unref();
   }
 
   void SetUp() {
@@ -126,24 +128,25 @@ class NcclTestBase : public ::testing::Test {
                              /*gpu_ring_order=*/nullptr, work_queue_);
 
     // Initialize collective params.
-    col_params_.name = "test_nccl_collective_op";
+    col_params_ = new CollectiveParams();
+    col_params_->name = "test_nccl_collective_op";
     const int group_key = num_ranks;
-    col_params_.group.group_key = group_key;
-    col_params_.group.device_type = DEVICE_GPU;
-    col_params_.group.group_size = num_ranks;
-    col_params_.instance.instance_key = instance_key;
-    col_params_.instance.type = collective_type_;
-    col_params_.instance.data_type = DT_FLOAT;
-    col_params_.instance.impl_details.collective_name = collective_name_;
+    col_params_->group.group_key = group_key;
+    col_params_->group.device_type = DEVICE_GPU;
+    col_params_->group.group_size = num_ranks;
+    col_params_->instance.instance_key = instance_key;
+    col_params_->instance.type = collective_type_;
+    col_params_->instance.data_type = DT_FLOAT;
+    col_params_->instance.impl_details.collective_name = collective_name_;
     const string task_name = "/job:worker/replica:0/task:0";
-    col_params_.group.num_devices_per_task[task_name] = num_ranks;
+    col_params_->group.num_devices_per_task[task_name] = num_ranks;
     for (int rank = 0; rank < num_ranks; ++rank) {
-      col_params_.group.device_names.push_back(device_names[rank % num_gpus]);
-      col_params_.group.task_names.push_back(task_name);
+      col_params_->group.device_names.push_back(device_names[rank % num_gpus]);
+      col_params_->group.task_names.push_back(task_name);
     }
     for (int rank = 0; rank < num_ranks; ++rank) {
       instances_.push_back(absl::make_unique<DeviceInstance>(
-          rank, col_params_.group.device_names[rank], this));
+          rank, col_params_->group.device_names[rank], this));
     }
   }
 
@@ -244,18 +247,23 @@ class NcclTestBase : public ::testing::Test {
   class DeviceInstance {
    public:
     DeviceInstance(int rank, const string& device_name, NcclTestBase* parent)
-        : parent_(parent), device_name_(device_name), rank_(rank) {
+        : parent_(parent),
+          device_name_(device_name),
+          rank_(rank),
+          col_params_(new CollectiveParams()) {
       TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(device_name_, &device_))
           << "Could not find device " << device_name_ << " existing devices "
          << parent_->dev_mgr_->DebugString();
       merge_op_ = GetAdd(device_);
       final_op_ = GetDiv(device_);
-      col_params_.name = parent_->col_params_.name;
-      col_params_.default_rank = rank;
-      col_params_.group = parent_->col_params_.group;
-      col_params_.instance = parent->col_params_.instance;
+      col_params_->name = parent_->col_params_->name;
+      col_params_->default_rank = rank;
+      col_params_->group = parent_->col_params_->group;
+      col_params_->instance = parent->col_params_->instance;
     }
+
+    ~DeviceInstance() { col_params_->Unref(); }
 
     void InitTensor(DataType dtype, const TensorShape& shape,
                     const std::function<void(Tensor*)>& init_f) {
       input_ =
@@ -304,7 +312,7 @@ class NcclTestBase : public ::testing::Test {
       AllocatorAttributes generic_alloc_attr;
       op_params.output_attr_array = &generic_alloc_attr;
       std::unique_ptr<OpKernel> op =
-          parent_->GetCollectiveReduceOpKernel(col_params_, &input_, device_);
+          parent_->GetCollectiveReduceOpKernel(*col_params_, &input_, device_);
       op_params.op_kernel = op.get();
       OpKernelContext ctx(&op_params, 1);
       // We never actually execute the kernel, so we need to do the output
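Since col_params_ is now a pointer, this call site dereferences it. The exact signature of GetCollectiveReduceOpKernel is not shown in this hunk; assuming it takes the params by const reference, the callee merely borrows them for the duration of the call and ownership stays with the test, as sketched here with stand-in names:

    struct Params {
      int refcount = 1;
      void Unref() {
        if (--refcount == 0) delete this;
      }
    };

    // Borrows the params; does not retain a reference past the call.
    void MakeKernel(const Params& p) { /* read-only use of p */ }

    int main() {
      Params* params = new Params();
      MakeKernel(*params);  // dereference the owning pointer to lend it out
      params->Unref();      // the caller still owns the single reference
      return 0;
    }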
@@ -316,7 +324,7 @@ class NcclTestBase : public ::testing::Test {
 
       // Run the all-reduce.
       string exec_key =
-          strings::StrCat(col_params_.instance.instance_key, ":0:0");
+          strings::StrCat(col_params_->instance.instance_key, ":0:0");
       auto* reducer = new NcclReducer();
       auto col_ctx = std::make_shared<CollectiveContext>(
           parent_->col_exec_, parent_->nccl_communicator_.get(),
@@ -340,7 +348,7 @@ class NcclTestBase : public ::testing::Test {
 
     void RunBroadcast() {
       VLOG(2) << "RunBroadcast name " << parent_->collective_name_ << " rank "
-              << col_params_.default_rank;
+              << col_params_->default_rank;
       // Prepare an OpKernelContext.
       OpKernelContext::Params op_params;
       PrepareDeviceContext(&op_params);
@@ -348,13 +356,13 @@ class NcclTestBase : public ::testing::Test {
 
       // Run broadcast.
       string exec_key =
-          strings::StrCat(col_params_.instance.instance_key, ":0:0");
+          strings::StrCat(col_params_->instance.instance_key, ":0:0");
       auto* broadcaster = new NcclBroadcaster();
       auto col_ctx = std::make_shared<CollectiveContext>(
           parent_->col_exec_, parent_->nccl_communicator_.get(),
           parent_->dev_mgr_.get(),
           /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
-          /*input=*/col_params_.is_source ? &input_ : nullptr,
+          /*input=*/col_params_->is_source ? &input_ : nullptr,
           /*output=*/&input_);
       TF_CHECK_OK(broadcaster->InitializeCollectiveContext(col_ctx));
       Notification note;
@@ -373,7 +381,7 @@ class NcclTestBase : public ::testing::Test {
 
     void RunGather() {
       VLOG(2) << "RunGather name " << parent_->collective_name_ << " rank "
-              << col_params_.default_rank;
+              << col_params_->default_rank;
       // Prepare an OpKernelContext.
       OpKernelContext::Params op_params;
       PrepareDeviceContext(&op_params);
@@ -383,13 +391,13 @@ class NcclTestBase : public ::testing::Test {
       // different shape.
       auto output_shape = input_.shape();
       output_shape.set_dim(
-          0, output_shape.dim_size(0) * col_params_.group.group_size);
+          0, output_shape.dim_size(0) * col_params_->group.group_size);
       output_ = Tensor(device_->GetAllocator(AllocatorAttributes()), DT_FLOAT,
                        output_shape);
 
       // Run gather.
       string exec_key =
-          strings::StrCat(col_params_.instance.instance_key, ":0:0");
+          strings::StrCat(col_params_->instance.instance_key, ":0:0");
       auto* gatherer = new NcclGatherer();
       auto col_ctx = std::make_shared<CollectiveContext>(
           parent_->col_exec_, parent_->nccl_communicator_.get(),
@@ -415,7 +423,7 @@ class NcclTestBase : public ::testing::Test {
     Tensor input_;
     Tensor output_;
     Device* device_;
-    CollectiveParams col_params_;
+    CollectiveParams* col_params_;
    std::unique_ptr<OpKernel> merge_op_;
    std::unique_ptr<OpKernel> final_op_;
    Status status_;
@@ -430,7 +438,7 @@ class NcclTestBase : public ::testing::Test {
   CollectiveExecutor* col_exec_;
   std::unique_ptr<DeviceMgr> dev_mgr_;
   std::vector<std::unique_ptr<DeviceInstance>> instances_;
-  CollectiveParams col_params_;
+  CollectiveParams* col_params_;
   mutex mu_;
   int32 op_counter_ TF_GUARDED_BY(mu_) = 0;
 };
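Both classes keep a raw owning pointer rather than std::unique_ptr. That is the natural choice for a reference-counted type: it must be released through Unref(), not delete, so a smart pointer would need a custom deleter. A hypothetical alternative, shown only for contrast with the diff:

    #include <memory>

    struct Params {
      int refcount = 1;
      void Unref() {
        if (--refcount == 0) delete this;
      }
    };

    struct UnrefDeleter {
      void operator()(Params* p) const {
        if (p) p->Unref();  // release the reference; never plain delete
      }
    };

    int main() {
      std::unique_ptr<Params, UnrefDeleter> params(new Params());
      return 0;
    }  // the deleter runs Unref(), dropping the last reference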
@@ -463,8 +471,8 @@ class NcclReducerTest : public NcclTestBase {
   }
 
   void InitDevice(DeviceInstance* di) override {
-    di->col_params_.merge_op = di->merge_op_.get();
-    di->col_params_.final_op = di->final_op_.get();
+    di->col_params_->merge_op = di->merge_op_.get();
+    di->col_params_->final_op = di->final_op_.get();
   }
 
   void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunReduce(); }
@@ -493,8 +501,8 @@ class NcclBroadcasterTest : public NcclTestBase {
   }
 
   void InitDevice(DeviceInstance* di) override {
-    di->col_params_.source_rank = source_rank_;
-    di->col_params_.is_source = di->col_params_.default_rank == source_rank_;
+    di->col_params_->source_rank = source_rank_;
+    di->col_params_->is_source = di->col_params_->default_rank == source_rank_;
   }
 
   void RunCollectiveOnDevice(DeviceInstance* di) override {