Share ownership of CollectiveContext with kernels.

BaseCollectiveExecutor creates a CollectiveContext and passes a pointer to each
collective kernel implementation.  The CollectiveContext is deleted in the done
callback.  However, for some kernels like the NCCL reducer, it is possible that
the CollectiveContext is accessed after the NCCL CUDA kernel is enqueued on the
GPU stream.  This creates a race between the access and destruction.

This change converts CollectiveContext from a raw pointer to a shared pointer,
essentially sharing ownership of this object with the kernel.  Thus, even if
the done callback runs first, the kernel can still safely access the context.

Resolves #41113.

PiperOrigin-RevId: 321426944
Change-Id: I9f12fe403bf2cc0939006dbde38ec2985d75cfcd
This commit is contained in:
Ayush Dubey 2020-07-15 13:27:14 -07:00 committed by TensorFlower Gardener
parent ad38b201b3
commit 600b4f145b
13 changed files with 47 additions and 41 deletions

View File

@@ -271,13 +271,12 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
DCHECK_EQ(nullptr, col_impl);
return;
}
CollectiveContext* col_ctx =
new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
exec_key, step_id_, input, output);
auto col_ctx = std::make_shared<CollectiveContext>(
this, dev_mgr_, ctx, CtxParams(ctx), col_params, exec_key, step_id_,
input, output);
status = col_impl->InitializeCollectiveContext(col_ctx);
if (!status.ok()) {
done_safe(status);
delete col_ctx;
delete col_impl;
return;
}
@@ -293,7 +292,6 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
profiler::TraceMeLevel::kInfo);
col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
done_safe(s);
delete col_ctx;
delete col_impl;
});
});

View File

@@ -186,7 +186,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
}
Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
CollectiveContext* col_ctx) {
std::shared_ptr<CollectiveContext> col_ctx) {
CHECK(col_ctx->dev_mgr);
col_ctx_ = col_ctx;
col_params_ = &col_ctx->col_params;

View File

@@ -39,7 +39,8 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
// Initializes members of CollectiveContext not yet initialized, i.e. device
// and device_locality. Also saves the CollectiveContext in this object.
Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) override;
// No-op for hierarchical tree broadcaster.
Status InitializeCollectiveGroupRuntimeDetails(
@@ -80,7 +81,7 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
// Executes the hierarchical broadcast defined by this op.
void RunTree();
CollectiveContext* col_ctx_; // Not owned
std::shared_ptr<CollectiveContext> col_ctx_;
const CollectiveParams* col_params_; // Not owned
StatusCallback done_;
Status status_;

View File

@@ -670,10 +670,10 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
HierarchicalTreeBroadcaster broadcaster;
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
&ctx, &op_params, col_params_, exec_key,
kStepId, input_tensor_ptr, output_tensor_ptr);
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
col_params_, exec_key, kStepId, input_tensor_ptr, output_tensor_ptr);
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
// Run the broadcast.
broadcaster.Run([this](Status s) { status_ = s; });

View File

@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/ring_alg.h"
#include <stdlib.h>
#include <atomic>
#include <functional>
#include <utility>
@@ -240,7 +241,8 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
return Status::OK();
}
Status RingAlg::InitializeCollectiveContext(CollectiveContext* col_ctx) {
Status RingAlg::InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) {
DCHECK(col_ctx->dev_mgr);
col_ctx_ = col_ctx;
col_params_ = &col_ctx->col_params;

View File

@@ -39,7 +39,8 @@ class RingAlg : public CollectiveImplementationInterface {
// Initializes members of CollectiveContext not yet initialized, i.e. device
// and device_locality. Also saves the CollectiveContext in this object.
Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) override;
// No-op for ring alg.
Status InitializeCollectiveGroupRuntimeDetails(
@@ -108,7 +109,7 @@ class RingAlg : public CollectiveImplementationInterface {
const CollectiveType type_;
const string name_;
CollectiveContext* col_ctx_; // Not owned
std::shared_ptr<CollectiveContext> col_ctx_;
const CollectiveParams* col_params_; // Not owned
StatusCallback done_;
int group_size_;

View File

@@ -477,10 +477,10 @@ class RingGathererTest : public ::testing::Test {
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
RingGatherer gatherer;
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
&ctx, &op_params, col_params_, exec_key,
kStepId, &input_tensor_, output_tensor_ptr);
TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx));
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
col_params_, exec_key, kStepId, &input_tensor_, output_tensor_ptr);
TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
// Run the all-gather.
gatherer.Run([this](Status s) { status_ = s; });

View File

@@ -507,10 +507,10 @@ class RingReducerTest : public ::testing::Test {
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
RingReducer reducer;
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
&ctx, &op_params, col_params_, exec_key,
kStepId, &tensor_, &tensor_);
TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
col_params_, exec_key, kStepId, &tensor_, &tensor_);
TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
// Run the all-reduce.
reducer.Run([this](Status s) { status_ = s; });

View File

@@ -327,7 +327,8 @@ class MockNcclReducer : public CollectiveImplementationInterface {
Status InitializeCollectiveParams(CollectiveParams*) override {
return Status::OK();
}
Status InitializeCollectiveContext(CollectiveContext*) override {
Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext>) override {
return Status::OK();
}
Status InitializeCollectiveGroupRuntimeDetails(

View File

@@ -399,7 +399,8 @@ class CollectiveImplementationInterface {
// Called from CollectiveExecutor right before calling Run(). The
// CollectiveContext passed in must outlive the CollectiveImplementation
// object.
virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
virtual Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) = 0;
// Performs collective implementation specific group initialization. The
// intention is to do group-specific initialization of runtime details for the

View File

@@ -58,7 +58,8 @@ Status NcclBase::InitializeCollectiveParams(CollectiveParams* col_params) {
return Status::OK();
}
Status NcclBase::InitializeCollectiveContext(CollectiveContext* col_ctx) {
Status NcclBase::InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) {
col_ctx_ = col_ctx;
col_params_ = &col_ctx->col_params;
return collective_util::InitializeDeviceAndLocality(

View File

@@ -29,7 +29,8 @@ class NcclBase : public CollectiveImplementationInterface {
Status InitializeCollectiveParams(CollectiveParams* col_params) override;
// Initializes the device objects and device localities.
Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) override;
// Initialize nccl communicator key.
Status InitializeCollectiveGroupRuntimeDetails(
@@ -40,7 +41,7 @@ class NcclBase : public CollectiveImplementationInterface {
const CollectiveType type_;
const string name_;
CollectiveContext* col_ctx_; // Not owned
std::shared_ptr<CollectiveContext> col_ctx_;
const CollectiveParams* col_params_; // Not owned
};

View File

@@ -314,11 +314,11 @@ class NcclTestBase : public ::testing::Test {
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
NcclReducer reducer;
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
/*OpKernelContext=*/&ctx, &op_params,
col_params_, exec_key, kStepId,
/*input=*/&input_, /*output=*/&input_);
TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(),
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
/*input=*/&input_, /*output=*/&input_);
TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
Notification note;
reducer.Run([this, &note](Status s) {
status_ = s;
@@ -344,12 +344,12 @@ class NcclTestBase : public ::testing::Test {
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
NcclBroadcaster broadcaster;
CollectiveContext col_ctx(
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(),
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
/*input=*/col_params_.is_source ? &input_ : nullptr,
/*output=*/&input_);
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
Notification note;
broadcaster.Run([this, &note](Status s) {
status_ = s;
@@ -383,12 +383,12 @@ class NcclTestBase : public ::testing::Test {
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
NcclGatherer gatherer;
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
/*OpKernelContext=*/&ctx, &op_params,
col_params_, exec_key, kStepId,
/*input=*/&input_,
/*output=*/&output_);
TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx));
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(),
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
/*input=*/&input_,
/*output=*/&output_);
TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
Notification note;
gatherer.Run([this, &note](Status s) {
status_ = s;