Fix collective_nccl_test after breakage in cl/355646163.
PiperOrigin-RevId: 355690665 Change-Id: I5ad78388c4d0e2b0be7daad1d6a136c1a8444c3d
This commit is contained in:
parent
a0bd36e7f4
commit
6dd7d00588
@ -88,10 +88,12 @@ class NcclTestBase : public ::testing::Test {
|
||||
nccl_communicator_(MaybeCreateNcclCommunicator()),
|
||||
work_queue_(std::make_shared<UnboundedWorkQueue>(
|
||||
Env::Default(), "collective_executor")),
|
||||
col_exec_(nullptr) {}
|
||||
col_exec_(nullptr),
|
||||
col_params_(nullptr) {}
|
||||
|
||||
~NcclTestBase() override {
|
||||
if (col_exec_) col_exec_->Unref();
|
||||
if (col_params_) col_params_->Unref();
|
||||
}
|
||||
|
||||
void SetUp() {
|
||||
@ -126,24 +128,25 @@ class NcclTestBase : public ::testing::Test {
|
||||
/*gpu_ring_order=*/nullptr, work_queue_);
|
||||
|
||||
// Initialize collective params.
|
||||
col_params_.name = "test_nccl_collective_op";
|
||||
col_params_ = new CollectiveParams();
|
||||
col_params_->name = "test_nccl_collective_op";
|
||||
const int group_key = num_ranks;
|
||||
col_params_.group.group_key = group_key;
|
||||
col_params_.group.device_type = DEVICE_GPU;
|
||||
col_params_.group.group_size = num_ranks;
|
||||
col_params_.instance.instance_key = instance_key;
|
||||
col_params_.instance.type = collective_type_;
|
||||
col_params_.instance.data_type = DT_FLOAT;
|
||||
col_params_.instance.impl_details.collective_name = collective_name_;
|
||||
col_params_->group.group_key = group_key;
|
||||
col_params_->group.device_type = DEVICE_GPU;
|
||||
col_params_->group.group_size = num_ranks;
|
||||
col_params_->instance.instance_key = instance_key;
|
||||
col_params_->instance.type = collective_type_;
|
||||
col_params_->instance.data_type = DT_FLOAT;
|
||||
col_params_->instance.impl_details.collective_name = collective_name_;
|
||||
const string task_name = "/job:worker/replica:0/task:0";
|
||||
col_params_.group.num_devices_per_task[task_name] = num_ranks;
|
||||
col_params_->group.num_devices_per_task[task_name] = num_ranks;
|
||||
for (int rank = 0; rank < num_ranks; ++rank) {
|
||||
col_params_.group.device_names.push_back(device_names[rank % num_gpus]);
|
||||
col_params_.group.task_names.push_back(task_name);
|
||||
col_params_->group.device_names.push_back(device_names[rank % num_gpus]);
|
||||
col_params_->group.task_names.push_back(task_name);
|
||||
}
|
||||
for (int rank = 0; rank < num_ranks; ++rank) {
|
||||
instances_.push_back(absl::make_unique<DeviceInstance>(
|
||||
rank, col_params_.group.device_names[rank], this));
|
||||
rank, col_params_->group.device_names[rank], this));
|
||||
}
|
||||
}
|
||||
|
||||
@ -244,18 +247,23 @@ class NcclTestBase : public ::testing::Test {
|
||||
class DeviceInstance {
|
||||
public:
|
||||
DeviceInstance(int rank, const string& device_name, NcclTestBase* parent)
|
||||
: parent_(parent), device_name_(device_name), rank_(rank) {
|
||||
: parent_(parent),
|
||||
device_name_(device_name),
|
||||
rank_(rank),
|
||||
col_params_(new CollectiveParams()) {
|
||||
TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(device_name_, &device_))
|
||||
<< "Could not find device " << device_name_ << " existing devices "
|
||||
<< parent_->dev_mgr_->DebugString();
|
||||
merge_op_ = GetAdd(device_);
|
||||
final_op_ = GetDiv(device_);
|
||||
col_params_.name = parent_->col_params_.name;
|
||||
col_params_.default_rank = rank;
|
||||
col_params_.group = parent_->col_params_.group;
|
||||
col_params_.instance = parent->col_params_.instance;
|
||||
col_params_->name = parent_->col_params_->name;
|
||||
col_params_->default_rank = rank;
|
||||
col_params_->group = parent_->col_params_->group;
|
||||
col_params_->instance = parent->col_params_->instance;
|
||||
}
|
||||
|
||||
~DeviceInstance() { col_params_->Unref(); }
|
||||
|
||||
void InitTensor(DataType dtype, const TensorShape& shape,
|
||||
const std::function<void(Tensor*)>& init_f) {
|
||||
input_ =
|
||||
@ -304,7 +312,7 @@ class NcclTestBase : public ::testing::Test {
|
||||
AllocatorAttributes generic_alloc_attr;
|
||||
op_params.output_attr_array = &generic_alloc_attr;
|
||||
std::unique_ptr<OpKernel> op =
|
||||
parent_->GetCollectiveReduceOpKernel(col_params_, &input_, device_);
|
||||
parent_->GetCollectiveReduceOpKernel(*col_params_, &input_, device_);
|
||||
op_params.op_kernel = op.get();
|
||||
OpKernelContext ctx(&op_params, 1);
|
||||
// We never actually execute the kernel, so we need to do the output
|
||||
@ -316,7 +324,7 @@ class NcclTestBase : public ::testing::Test {
|
||||
|
||||
// Run the all-reduce.
|
||||
string exec_key =
|
||||
strings::StrCat(col_params_.instance.instance_key, ":0:0");
|
||||
strings::StrCat(col_params_->instance.instance_key, ":0:0");
|
||||
auto* reducer = new NcclReducer();
|
||||
auto col_ctx = std::make_shared<CollectiveContext>(
|
||||
parent_->col_exec_, parent_->nccl_communicator_.get(),
|
||||
@ -340,7 +348,7 @@ class NcclTestBase : public ::testing::Test {
|
||||
|
||||
void RunBroadcast() {
|
||||
VLOG(2) << "RunBroadcast name " << parent_->collective_name_ << " rank "
|
||||
<< col_params_.default_rank;
|
||||
<< col_params_->default_rank;
|
||||
// Prepare an OpKernelContext.
|
||||
OpKernelContext::Params op_params;
|
||||
PrepareDeviceContext(&op_params);
|
||||
@ -348,13 +356,13 @@ class NcclTestBase : public ::testing::Test {
|
||||
|
||||
// Run broadcast.
|
||||
string exec_key =
|
||||
strings::StrCat(col_params_.instance.instance_key, ":0:0");
|
||||
strings::StrCat(col_params_->instance.instance_key, ":0:0");
|
||||
auto* broadcaster = new NcclBroadcaster();
|
||||
auto col_ctx = std::make_shared<CollectiveContext>(
|
||||
parent_->col_exec_, parent_->nccl_communicator_.get(),
|
||||
parent_->dev_mgr_.get(),
|
||||
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
|
||||
/*input=*/col_params_.is_source ? &input_ : nullptr,
|
||||
/*input=*/col_params_->is_source ? &input_ : nullptr,
|
||||
/*output=*/&input_);
|
||||
TF_CHECK_OK(broadcaster->InitializeCollectiveContext(col_ctx));
|
||||
Notification note;
|
||||
@ -373,7 +381,7 @@ class NcclTestBase : public ::testing::Test {
|
||||
|
||||
void RunGather() {
|
||||
VLOG(2) << "RunGather name " << parent_->collective_name_ << " rank "
|
||||
<< col_params_.default_rank;
|
||||
<< col_params_->default_rank;
|
||||
// Prepare an OpKernelContext.
|
||||
OpKernelContext::Params op_params;
|
||||
PrepareDeviceContext(&op_params);
|
||||
@ -383,13 +391,13 @@ class NcclTestBase : public ::testing::Test {
|
||||
// different shape.
|
||||
auto output_shape = input_.shape();
|
||||
output_shape.set_dim(
|
||||
0, output_shape.dim_size(0) * col_params_.group.group_size);
|
||||
0, output_shape.dim_size(0) * col_params_->group.group_size);
|
||||
output_ = Tensor(device_->GetAllocator(AllocatorAttributes()), DT_FLOAT,
|
||||
output_shape);
|
||||
|
||||
// Run gather.
|
||||
string exec_key =
|
||||
strings::StrCat(col_params_.instance.instance_key, ":0:0");
|
||||
strings::StrCat(col_params_->instance.instance_key, ":0:0");
|
||||
auto* gatherer = new NcclGatherer();
|
||||
auto col_ctx = std::make_shared<CollectiveContext>(
|
||||
parent_->col_exec_, parent_->nccl_communicator_.get(),
|
||||
@ -415,7 +423,7 @@ class NcclTestBase : public ::testing::Test {
|
||||
Tensor input_;
|
||||
Tensor output_;
|
||||
Device* device_;
|
||||
CollectiveParams col_params_;
|
||||
CollectiveParams* col_params_;
|
||||
std::unique_ptr<OpKernel> merge_op_;
|
||||
std::unique_ptr<OpKernel> final_op_;
|
||||
Status status_;
|
||||
@ -430,7 +438,7 @@ class NcclTestBase : public ::testing::Test {
|
||||
CollectiveExecutor* col_exec_;
|
||||
std::unique_ptr<DeviceMgr> dev_mgr_;
|
||||
std::vector<std::unique_ptr<DeviceInstance>> instances_;
|
||||
CollectiveParams col_params_;
|
||||
CollectiveParams* col_params_;
|
||||
mutex mu_;
|
||||
int32 op_counter_ TF_GUARDED_BY(mu_) = 0;
|
||||
};
|
||||
@ -463,8 +471,8 @@ class NcclReducerTest : public NcclTestBase {
|
||||
}
|
||||
|
||||
void InitDevice(DeviceInstance* di) override {
|
||||
di->col_params_.merge_op = di->merge_op_.get();
|
||||
di->col_params_.final_op = di->final_op_.get();
|
||||
di->col_params_->merge_op = di->merge_op_.get();
|
||||
di->col_params_->final_op = di->final_op_.get();
|
||||
}
|
||||
|
||||
void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunReduce(); }
|
||||
@ -493,8 +501,8 @@ class NcclBroadcasterTest : public NcclTestBase {
|
||||
}
|
||||
|
||||
void InitDevice(DeviceInstance* di) override {
|
||||
di->col_params_.source_rank = source_rank_;
|
||||
di->col_params_.is_source = di->col_params_.default_rank == source_rank_;
|
||||
di->col_params_->source_rank = source_rank_;
|
||||
di->col_params_->is_source = di->col_params_->default_rank == source_rank_;
|
||||
}
|
||||
|
||||
void RunCollectiveOnDevice(DeviceInstance* di) override {
|
||||
|
Loading…
x
Reference in New Issue
Block a user