Make V2 collective ComputeAsync thread-safe

PiperOrigin-RevId: 338291666
Change-Id: I041f88b8b9f805f3660fe55b3ed0c5c04d7075d4
This commit is contained in:
Ran Chen 2020-10-21 10:33:47 -07:00 committed by TensorFlower Gardener
parent 8e84212c47
commit 54bc354d2e
6 changed files with 70 additions and 64 deletions

View File

@ -256,7 +256,7 @@ bool RingReducer::RunAsyncParts() {
rf->action = RF_REDUCE;
Status s = collective_util::ComputeBinOp(
col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
col_params_->merge_op.get(), &rf->chunk, &rf->tmp_chunk);
col_params_->merge_op, &rf->chunk, &rf->tmp_chunk);
if (!s.ok()) {
aborted = true;
StartAbort(s);
@ -266,13 +266,12 @@ bool RingReducer::RunAsyncParts() {
}
break;
case RF_REDUCE:
if (!rf->second_pass && col_params_->final_op.get() &&
rf->is_final) {
if (!rf->second_pass && col_params_->final_op && rf->is_final) {
rf->action = RF_FINALIZE;
group_size_tensor_ready_.WaitForNotification();
Status s = collective_util::ComputeBinOp(
col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
col_params_->final_op.get(), &rf->chunk, &group_size_tensor_);
col_params_->final_op, &rf->chunk, &group_size_tensor_);
if (!s.ok()) {
aborted = true;
StartAbort(s);

View File

@ -466,10 +466,10 @@ class RingReducerTest : public ::testing::Test {
}
void DoReduce() {
col_params_.merge_op =
GetAdd(col_params_.instance.data_type, device_type_, device_);
col_params_.final_op =
GetDiv(col_params_.instance.data_type, device_type_, device_);
merge_op_ = GetAdd(col_params_.instance.data_type, device_type_, device_);
final_op_ = GetDiv(col_params_.instance.data_type, device_type_, device_);
col_params_.merge_op = merge_op_.get();
col_params_.final_op = final_op_.get();
// Prepare an OpKernelContext.
OpKernelContext::Params op_params;
@ -536,6 +536,8 @@ class RingReducerTest : public ::testing::Test {
Tensor tensor_;
Device* device_;
CollectiveParams col_params_;
std::unique_ptr<OpKernel> merge_op_;
std::unique_ptr<OpKernel> final_op_;
std::unique_ptr<CollectiveAdapter> ca_;
std::unique_ptr<OpKernelContext> ctx_;
Status status_;

View File

@ -143,8 +143,8 @@ struct CollectiveParams {
int source_rank = -1; // broadcast only
// Rank of this device in each subdivision permutation.
std::vector<int> subdiv_rank;
std::unique_ptr<OpKernel> merge_op; // reduction only
std::unique_ptr<OpKernel> final_op; // reduction only
OpKernel* merge_op = nullptr; // reduction only
OpKernel* final_op = nullptr; // reduction only
string ToString() const;
};

View File

@ -113,7 +113,7 @@ void NcclReducer::Run(StatusCallback done) {
if (final_status.ok()) {
final_status = collective_util::ComputeBinOp(
col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
col_params_->final_op.get(), col_ctx_->output, &group_size);
col_params_->final_op, col_ctx_->output, &group_size);
}
done(final_status);
}

View File

@ -248,6 +248,8 @@ class NcclTestBase : public ::testing::Test {
TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(device_name_, &device_))
<< "Could not find device " << device_name_ << " existing devices "
<< parent_->dev_mgr_->DebugString();
merge_op_ = GetAdd(device_);
final_op_ = GetDiv(device_);
col_params_.name = parent_->col_params_.name;
col_params_.default_rank = rank;
col_params_.group = parent_->col_params_.group;
@ -414,6 +416,8 @@ class NcclTestBase : public ::testing::Test {
Tensor output_;
Device* device_;
CollectiveParams col_params_;
std::unique_ptr<OpKernel> merge_op_;
std::unique_ptr<OpKernel> final_op_;
Status status_;
};
@ -459,8 +463,8 @@ class NcclReducerTest : public NcclTestBase {
}
void InitDevice(DeviceInstance* di) override {
di->col_params_.merge_op = GetAdd(di->device_);
di->col_params_.final_op = GetDiv(di->device_);
di->col_params_.merge_op = di->merge_op_.get();
di->col_params_.final_op = di->final_op_.get();
}
void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunReduce(); }

View File

@ -271,8 +271,10 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
sub_node.set_device(real_node.device());
SetAttrValue(col_params_.instance.data_type,
&(*sub_node.mutable_attr())["T"]);
col_params_.merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
col_params_.merge_op = merge_op_.get();
col_params_.final_op = final_op_.get();
}
protected:
@ -311,6 +313,8 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
}
private:
std::unique_ptr<OpKernel> merge_op_;
std::unique_ptr<OpKernel> final_op_;
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
};
@ -469,9 +473,8 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
public:
explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
col_params_ = std::make_shared<CollectiveParams>();
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
: CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
string merge_op_name;
OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
@ -482,32 +485,23 @@ class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
}
string final_op_name;
OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_->instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_->instance.impl_details.timeout_seconds));
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
// Prepare OpKernels for reduction and final operations.
// The merge_op takes two inputs
NodeDef sub_node;
sub_node.add_input(c->def().input(0));
sub_node.add_input(c->def().input(0));
sub_node.set_device(c->def().device());
SetAttrValue(col_params_->instance.data_type,
&(*sub_node.mutable_attr())["T"]);
col_params_->merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
col_params_->final_op = BuildOpKernel(c, final_op_name, &sub_node);
SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]);
merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
col_params_->name = strings::StrCat(c->def().name(), ": ReduceV2(",
merge_op_name, ",", final_op_name, ")");
col_params_->group.device_type = c->device_type();
// Add a default value for subdiv offsets, which is the same as the default
// value in the V1 op's attribute.
col_params_->instance.impl_details.subdiv_offsets.push_back(0);
VLOG(2) << "CollectiveReduceV2 " << this << " name " << col_params_->name
<< " communication_hint "
<< col_params_->instance.impl_details.communication_hint;
name_ = strings::StrCat(c->def().name(), ": ReduceV2(", merge_op_name, ",",
final_op_name, ")");
device_type_ = c->device_type();
VLOG(2) << "CollectiveReduceV2 " << this << " name " << name_
<< " communication_hint " << communication_hint_;
}
protected:
@ -527,48 +521,49 @@ class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
c, instance_key.dims() == 0,
errors::Internal("Unexpected dimensions on input instance_key"), done);
auto col_params = std::make_shared<CollectiveParams>();
col_params->name = col_params_->name;
col_params->group.device_type = col_params_->group.device_type;
auto col_params = new CollectiveParams();
col_params->name = name_;
col_params->group.device_type = device_type_;
col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
col_params->instance.type = REDUCTION_COLLECTIVE;
col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
col_params->instance.data_type = col_params_->instance.data_type;
col_params->instance.impl_details.communication_hint =
col_params_->instance.impl_details.communication_hint;
col_params->instance.impl_details.timeout_seconds =
col_params_->instance.impl_details.timeout_seconds;
col_params->instance.impl_details.subdiv_offsets =
col_params_->instance.impl_details.subdiv_offsets;
col_params->merge_op = std::move(col_params_->merge_op);
col_params->final_op = std::move(col_params_->final_op);
col_params->instance.data_type = data_type_;
col_params->instance.impl_details.communication_hint = communication_hint_;
col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
// Add a default value for subdiv offsets, which is the same as the default
// value in the V1 op's attribute.
col_params->instance.impl_details.subdiv_offsets.push_back(0);
col_params->merge_op = merge_op_.get();
col_params->final_op = final_op_.get();
VLOG(1) << "CollectiveReduceV2 group_size " << col_params->group.group_size
<< " group_key " << col_params->group.group_key << " instance_key "
<< col_params->instance.instance_key;
auto done_with_cleanup = [col_params, done = std::move(done)]() {
delete col_params;
done();
};
// Allocate the output tensor, trying to reuse the input.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->forward_input_or_allocate_output({0}, 0, input.shape(), &output),
done);
done_with_cleanup);
col_params->instance.shape = input.shape();
// Store the updated params in this OpKernel.
col_params_ = col_params;
// Resolve the collective params.
// Schedule the `CompleteParamsAsync` call on a work queue that can handle
// blocking work because it's not guaranteed that this call cannot block.
c->collective_executor()->RunClosure([c, done = std::move(done), col_params,
col_exec]() {
c->collective_executor()->RunClosure([c,
done = std::move(done_with_cleanup),
col_params, col_exec]() {
VLOG(1) << "CollectiveReduceV2 CompleteParams for collective "
<< col_params->name << " device " << c->device()->name()
<< " group " << col_params->group.group_key << " instance "
<< col_params->instance.instance_key;
col_exec->CompleteParamsAsync(
c->device()->attributes(), col_params.get(),
c->cancellation_manager(),
c->device()->attributes(), col_params, c->cancellation_manager(),
[c, done = std::move(done), col_params, col_exec](const Status& s) {
if (s.ok()) {
auto actual_done = [c, group_key = col_params->group.group_key,
@ -602,7 +597,12 @@ class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
}
private:
std::shared_ptr<CollectiveParams> col_params_;
DataType data_type_ = DT_INVALID;
string communication_hint_;
float timeout_seconds_ = 0;
DeviceType device_type_;
std::unique_ptr<OpKernel> merge_op_;
std::unique_ptr<OpKernel> final_op_;
};
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2").Device(DEVICE_CPU),
@ -673,15 +673,16 @@ class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
0, output_shape.dim_size(0) * col_params->group.group_size);
col_params->instance.shape = output_shape;
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->allocate_output(0, col_params->instance.shape, &output), done);
auto done_with_cleanup = [col_params, done = std::move(done)]() {
delete col_params;
done();
};
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->allocate_output(0, col_params->instance.shape, &output),
done_with_cleanup);
// Resolve the collective params.
// Schedule the `CompleteParamsAsync` call on a work queue that can handle
// blocking work because it's not guaranteed that this call cannot block.
@ -727,9 +728,9 @@ class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
}
private:
DataType data_type_;
DataType data_type_ = DT_INVALID;
string communication_hint_;
float timeout_seconds_;
float timeout_seconds_ = 0;
DeviceType device_type_;
};