Make V2 collective ComputeAsync thread safe
PiperOrigin-RevId: 338291666
Change-Id: I041f88b8b9f805f3660fe55b3ed0c5c04d7075d4
parent 8e84212c47
commit 54bc354d2e
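The change removes kernel ownership from CollectiveParams: merge_op and final_op become plain OpKernel* pointers, the owning std::unique_ptr<OpKernel> members move into the op kernels and test fixtures, and the V2 reduce op builds a fresh CollectiveParams per ComputeAsync call instead of mutating a shared one. Below is a minimal sketch of that ownership model, assuming simplified stand-in types; OpKernel, ReduceKernelSketch, Init, and MakeParams here are illustrative placeholders, not the TensorFlow API.

#include <memory>

// Stand-in for tensorflow::OpKernel; details are not needed for the sketch.
struct OpKernel {};

// After the change, CollectiveParams only points at the reduction kernels;
// it no longer owns them.
struct CollectiveParams {
  OpKernel* merge_op = nullptr;  // reduction only, not owned
  OpKernel* final_op = nullptr;  // reduction only, not owned
};

// The op kernel (or test fixture) owns the kernels for its whole lifetime
// and hands out raw pointers for each invocation.
class ReduceKernelSketch {
 public:
  void Init() {
    merge_op_ = std::make_unique<OpKernel>();
    final_op_ = std::make_unique<OpKernel>();
  }

  CollectiveParams MakeParams() const {
    CollectiveParams p;
    p.merge_op = merge_op_.get();
    p.final_op = final_op_.get();
    return p;
  }

 private:
  std::unique_ptr<OpKernel> merge_op_;
  std::unique_ptr<OpKernel> final_op_;
};

With this split, the long-lived kernel keeps the only owning pointers, so a per-invocation CollectiveParams can be created and destroyed freely without touching shared state.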
@@ -256,7 +256,7 @@ bool RingReducer::RunAsyncParts() {
           rf->action = RF_REDUCE;
           Status s = collective_util::ComputeBinOp(
               col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
-              col_params_->merge_op.get(), &rf->chunk, &rf->tmp_chunk);
+              col_params_->merge_op, &rf->chunk, &rf->tmp_chunk);
           if (!s.ok()) {
             aborted = true;
             StartAbort(s);
@@ -266,13 +266,12 @@ bool RingReducer::RunAsyncParts() {
         }
         break;
       case RF_REDUCE:
-        if (!rf->second_pass && col_params_->final_op.get() &&
-            rf->is_final) {
+        if (!rf->second_pass && col_params_->final_op && rf->is_final) {
           rf->action = RF_FINALIZE;
           group_size_tensor_ready_.WaitForNotification();
           Status s = collective_util::ComputeBinOp(
               col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
-              col_params_->final_op.get(), &rf->chunk, &group_size_tensor_);
+              col_params_->final_op, &rf->chunk, &group_size_tensor_);
           if (!s.ok()) {
             aborted = true;
             StartAbort(s);
@@ -466,10 +466,10 @@ class RingReducerTest : public ::testing::Test {
   }

   void DoReduce() {
-    col_params_.merge_op =
-        GetAdd(col_params_.instance.data_type, device_type_, device_);
-    col_params_.final_op =
-        GetDiv(col_params_.instance.data_type, device_type_, device_);
+    merge_op_ = GetAdd(col_params_.instance.data_type, device_type_, device_);
+    final_op_ = GetDiv(col_params_.instance.data_type, device_type_, device_);
+    col_params_.merge_op = merge_op_.get();
+    col_params_.final_op = final_op_.get();

     // Prepare an OpKernelContext.
     OpKernelContext::Params op_params;
@@ -536,6 +536,8 @@ class RingReducerTest : public ::testing::Test {
   Tensor tensor_;
   Device* device_;
   CollectiveParams col_params_;
+  std::unique_ptr<OpKernel> merge_op_;
+  std::unique_ptr<OpKernel> final_op_;
   std::unique_ptr<CollectiveAdapter> ca_;
   std::unique_ptr<OpKernelContext> ctx_;
   Status status_;
@@ -143,8 +143,8 @@ struct CollectiveParams {
   int source_rank = -1;  // broadcast only
   // Rank of this device in each subdivision permutation.
   std::vector<int> subdiv_rank;
-  std::unique_ptr<OpKernel> merge_op;  // reduction only
-  std::unique_ptr<OpKernel> final_op;  // reduction only
+  OpKernel* merge_op = nullptr;  // reduction only
+  OpKernel* final_op = nullptr;  // reduction only

   string ToString() const;
 };
@@ -113,7 +113,7 @@ void NcclReducer::Run(StatusCallback done) {
   if (final_status.ok()) {
     final_status = collective_util::ComputeBinOp(
         col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
-        col_params_->final_op.get(), col_ctx_->output, &group_size);
+        col_params_->final_op, col_ctx_->output, &group_size);
   }
   done(final_status);
 }
@@ -248,6 +248,8 @@ class NcclTestBase : public ::testing::Test {
       TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(device_name_, &device_))
           << "Could not find device " << device_name_ << " existing devices "
           << parent_->dev_mgr_->DebugString();
+      merge_op_ = GetAdd(device_);
+      final_op_ = GetDiv(device_);
      col_params_.name = parent_->col_params_.name;
      col_params_.default_rank = rank;
      col_params_.group = parent_->col_params_.group;
@@ -414,6 +416,8 @@ class NcclTestBase : public ::testing::Test {
     Tensor output_;
     Device* device_;
     CollectiveParams col_params_;
+    std::unique_ptr<OpKernel> merge_op_;
+    std::unique_ptr<OpKernel> final_op_;
     Status status_;
   };

@@ -459,8 +463,8 @@ class NcclReducerTest : public NcclTestBase {
   }

   void InitDevice(DeviceInstance* di) override {
-    di->col_params_.merge_op = GetAdd(di->device_);
-    di->col_params_.final_op = GetDiv(di->device_);
+    di->col_params_.merge_op = di->merge_op_.get();
+    di->col_params_.final_op = di->final_op_.get();
   }

   void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunReduce(); }
@@ -271,8 +271,10 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
     sub_node.set_device(real_node.device());
     SetAttrValue(col_params_.instance.data_type,
                  &(*sub_node.mutable_attr())["T"]);
-    col_params_.merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
-    col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
+    merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
+    final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
+    col_params_.merge_op = merge_op_.get();
+    col_params_.final_op = final_op_.get();
   }

  protected:
@@ -311,6 +313,8 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
   }

  private:
+  std::unique_ptr<OpKernel> merge_op_;
+  std::unique_ptr<OpKernel> final_op_;
   TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
 };

@@ -469,9 +473,8 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
 class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
  public:
   explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
-      : CollectiveOpKernel(c) {
-    col_params_ = std::make_shared<CollectiveParams>();
-    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
+      : CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
+    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
     string merge_op_name;
     OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
@@ -482,32 +485,23 @@ class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
     }
     string final_op_name;
     OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
-    OP_REQUIRES_OK(
-        c, c->GetAttr("communication_hint",
-                      &col_params_->instance.impl_details.communication_hint));
-    OP_REQUIRES_OK(
-        c, c->GetAttr("timeout_seconds",
-                      &col_params_->instance.impl_details.timeout_seconds));
+    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
+    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
     // Prepare OpKernels for reduction and final operations.
     // The merge_op takes two inputs
     NodeDef sub_node;
     sub_node.add_input(c->def().input(0));
     sub_node.add_input(c->def().input(0));
     sub_node.set_device(c->def().device());
-    SetAttrValue(col_params_->instance.data_type,
-                 &(*sub_node.mutable_attr())["T"]);
-    col_params_->merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
-    col_params_->final_op = BuildOpKernel(c, final_op_name, &sub_node);
+    SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]);
+    merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
+    final_op_ = BuildOpKernel(c, final_op_name, &sub_node);

-    col_params_->name = strings::StrCat(c->def().name(), ": ReduceV2(",
-                                        merge_op_name, ",", final_op_name, ")");
-    col_params_->group.device_type = c->device_type();
-    // Add a default value for subdiv offsets, which is the same as the default
-    // value in the V1 op's attribute.
-    col_params_->instance.impl_details.subdiv_offsets.push_back(0);
-    VLOG(2) << "CollectiveReduceV2 " << this << " name " << col_params_->name
-            << " communication_hint "
-            << col_params_->instance.impl_details.communication_hint;
+    name_ = strings::StrCat(c->def().name(), ": ReduceV2(", merge_op_name, ",",
+                            final_op_name, ")");
+    device_type_ = c->device_type();
+    VLOG(2) << "CollectiveReduceV2 " << this << " name " << name_
+            << " communication_hint " << communication_hint_;
   }

  protected:
@@ -527,48 +521,49 @@ class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
         c, instance_key.dims() == 0,
         errors::Internal("Unexpected dimensions on input instance_key"), done);

-    auto col_params = std::make_shared<CollectiveParams>();
-    col_params->name = col_params_->name;
-    col_params->group.device_type = col_params_->group.device_type;
+    auto col_params = new CollectiveParams();
+    col_params->name = name_;
+    col_params->group.device_type = device_type_;
     col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
     col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
     col_params->instance.type = REDUCTION_COLLECTIVE;
     col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
-    col_params->instance.data_type = col_params_->instance.data_type;
-    col_params->instance.impl_details.communication_hint =
-        col_params_->instance.impl_details.communication_hint;
-    col_params->instance.impl_details.timeout_seconds =
-        col_params_->instance.impl_details.timeout_seconds;
-    col_params->instance.impl_details.subdiv_offsets =
-        col_params_->instance.impl_details.subdiv_offsets;
-    col_params->merge_op = std::move(col_params_->merge_op);
-    col_params->final_op = std::move(col_params_->final_op);
+    col_params->instance.data_type = data_type_;
+    col_params->instance.impl_details.communication_hint = communication_hint_;
+    col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
+    // Add a default value for subdiv offsets, which is the same as the default
+    // value in the V1 op's attribute.
+    col_params->instance.impl_details.subdiv_offsets.push_back(0);
+    col_params->merge_op = merge_op_.get();
+    col_params->final_op = final_op_.get();
     VLOG(1) << "CollectiveReduceV2 group_size " << col_params->group.group_size
             << " group_key " << col_params->group.group_key << " instance_key "
             << col_params->instance.instance_key;

+    auto done_with_cleanup = [col_params, done = std::move(done)]() {
+      delete col_params;
+      done();
+    };
+
     // Allocate the output tensor, trying to reuse the input.
     Tensor* output = nullptr;
     OP_REQUIRES_OK_ASYNC(
         c, c->forward_input_or_allocate_output({0}, 0, input.shape(), &output),
-        done);
+        done_with_cleanup);
     col_params->instance.shape = input.shape();

-    // Store the updated params in this OpKernel.
-    col_params_ = col_params;
-
     // Resolve the collective params.
     // Schedule the `CompleteParamsAsync` call on a work queue that can handle
     // blocking work because it's not guaranteed that this call cannot block.
-    c->collective_executor()->RunClosure([c, done = std::move(done), col_params,
-                                          col_exec]() {
+    c->collective_executor()->RunClosure([c,
+                                          done = std::move(done_with_cleanup),
+                                          col_params, col_exec]() {
       VLOG(1) << "CollectiveReduceV2 CompleteParams for collective "
               << col_params->name << " device " << c->device()->name()
               << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key;
       col_exec->CompleteParamsAsync(
-          c->device()->attributes(), col_params.get(),
-          c->cancellation_manager(),
+          c->device()->attributes(), col_params, c->cancellation_manager(),
           [c, done = std::move(done), col_params, col_exec](const Status& s) {
             if (s.ok()) {
               auto actual_done = [c, group_key = col_params->group.group_key,
@@ -602,7 +597,12 @@ class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
   }

  private:
-  std::shared_ptr<CollectiveParams> col_params_;
+  DataType data_type_ = DT_INVALID;
+  string communication_hint_;
+  float timeout_seconds_ = 0;
+  DeviceType device_type_;
+  std::unique_ptr<OpKernel> merge_op_;
+  std::unique_ptr<OpKernel> final_op_;
 };

 REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2").Device(DEVICE_CPU),
@@ -673,15 +673,16 @@ class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
         0, output_shape.dim_size(0) * col_params->group.group_size);
     col_params->instance.shape = output_shape;

-    Tensor* output = nullptr;
-    OP_REQUIRES_OK_ASYNC(
-        c, c->allocate_output(0, col_params->instance.shape, &output), done);
-
     auto done_with_cleanup = [col_params, done = std::move(done)]() {
       delete col_params;
       done();
     };

+    Tensor* output = nullptr;
+    OP_REQUIRES_OK_ASYNC(
+        c, c->allocate_output(0, col_params->instance.shape, &output),
+        done_with_cleanup);
+
     // Resolve the collective params.
     // Schedule the `CompleteParamsAsync` call on a work queue that can handle
     // blocking work because it's not guaranteed that this call cannot block.
@@ -727,9 +728,9 @@ class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
   }

  private:
-  DataType data_type_;
+  DataType data_type_ = DT_INVALID;
   string communication_hint_;
-  float timeout_seconds_;
+  float timeout_seconds_ = 0;
   DeviceType device_type_;
 };

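For reference, here is a minimal sketch of the per-invocation pattern the V2 ComputeAsync hunk above adopts, under the assumption that the done callback is invoked exactly once per call; Params and RunAsync are placeholder names for illustration, not the real TensorFlow types.

#include <functional>
#include <utility>

// Placeholder for the per-collective parameter struct.
struct Params {
  int instance_key = 0;
};

// Each call builds its own Params on the heap; a wrapper callback deletes it
// once the async work has finished, so concurrent calls never touch shared
// mutable state.
void RunAsync(int instance_key, std::function<void()> done) {
  auto* params = new Params();
  params->instance_key = instance_key;

  auto done_with_cleanup = [params, done = std::move(done)]() {
    delete params;  // freed only after the collective completes
    done();
  };

  // ... hand `params` and `done_with_cleanup` to the async collective here;
  // every success and error path must invoke done_with_cleanup exactly once.
  done_with_cleanup();
}

Because every invocation owns its params and frees them in the wrapped callback, concurrent ComputeAsync calls no longer race on a shared CollectiveParams, which is the thread-safety problem the commit title refers to.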