Share ownership of CollectiveContext with kernels.

BaseCollectiveExecutor creates a CollectiveContext and passes a pointer to each
collective kernel implementation.  The CollectiveContext is deleted in the done
callback.  However, for some kernels like the NCCL reducer, it is possible that
the CollectiveContext is accessed after the NCCL CUDA kernel is enqueued on the
GPU stream.  This creates a race between the access and destruction.

This change changes CollectiveContext from a raw pointer to a shared pointer,
essentially sharing ownership of this object with the kernel.  Thus, even if
the done callback runs first, the kernel can still safely access the context.

Resolves #41113.

PiperOrigin-RevId: 321426944
Change-Id: I9f12fe403bf2cc0939006dbde38ec2985d75cfcd
This commit is contained in:
Ayush Dubey 2020-07-15 13:27:14 -07:00 committed by TensorFlower Gardener
parent ad38b201b3
commit 600b4f145b
13 changed files with 47 additions and 41 deletions

View File

@ -271,13 +271,12 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
DCHECK_EQ(nullptr, col_impl); DCHECK_EQ(nullptr, col_impl);
return; return;
} }
CollectiveContext* col_ctx = auto col_ctx = std::make_shared<CollectiveContext>(
new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params, this, dev_mgr_, ctx, CtxParams(ctx), col_params, exec_key, step_id_,
exec_key, step_id_, input, output); input, output);
status = col_impl->InitializeCollectiveContext(col_ctx); status = col_impl->InitializeCollectiveContext(col_ctx);
if (!status.ok()) { if (!status.ok()) {
done_safe(status); done_safe(status);
delete col_ctx;
delete col_impl; delete col_impl;
return; return;
} }
@ -293,7 +292,6 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
profiler::TraceMeLevel::kInfo); profiler::TraceMeLevel::kInfo);
col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) { col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
done_safe(s); done_safe(s);
delete col_ctx;
delete col_impl; delete col_impl;
}); });
}); });

View File

@ -186,7 +186,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
} }
Status HierarchicalTreeBroadcaster::InitializeCollectiveContext( Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
CollectiveContext* col_ctx) { std::shared_ptr<CollectiveContext> col_ctx) {
CHECK(col_ctx->dev_mgr); CHECK(col_ctx->dev_mgr);
col_ctx_ = col_ctx; col_ctx_ = col_ctx;
col_params_ = &col_ctx->col_params; col_params_ = &col_ctx->col_params;

View File

@ -39,7 +39,8 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
// Initializes members of CollectiveContext not yet initialized, i.e. device // Initializes members of CollectiveContext not yet initialized, i.e. device
// and device_locality. Also saves the CollectiveContext in this object. // and device_locality. Also saves the CollectiveContext in this object.
Status InitializeCollectiveContext(CollectiveContext* col_ctx) override; Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) override;
// No-op for hierarchical tree broadcaster. // No-op for hierarchical tree broadcaster.
Status InitializeCollectiveGroupRuntimeDetails( Status InitializeCollectiveGroupRuntimeDetails(
@ -80,7 +81,7 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
// Executes the hierarchical broadcast defined by this op. // Executes the hierarchical broadcast defined by this op.
void RunTree(); void RunTree();
CollectiveContext* col_ctx_; // Not owned std::shared_ptr<CollectiveContext> col_ctx_;
const CollectiveParams* col_params_; // Not owned const CollectiveParams* col_params_; // Not owned
StatusCallback done_; StatusCallback done_;
Status status_; Status status_;

View File

@ -670,10 +670,10 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
string exec_key = string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0"); strings::StrCat(col_params_.instance.instance_key, ":0:0");
HierarchicalTreeBroadcaster broadcaster; HierarchicalTreeBroadcaster broadcaster;
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(), auto col_ctx = std::make_shared<CollectiveContext>(
&ctx, &op_params, col_params_, exec_key, parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
kStepId, input_tensor_ptr, output_tensor_ptr); col_params_, exec_key, kStepId, input_tensor_ptr, output_tensor_ptr);
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx)); TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
// Run the broadcast. // Run the broadcast.
broadcaster.Run([this](Status s) { status_ = s; }); broadcaster.Run([this](Status s) { status_ = s; });

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/ring_alg.h" #include "tensorflow/core/common_runtime/ring_alg.h"
#include <stdlib.h> #include <stdlib.h>
#include <atomic> #include <atomic>
#include <functional> #include <functional>
#include <utility> #include <utility>
@ -240,7 +241,8 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
return Status::OK(); return Status::OK();
} }
Status RingAlg::InitializeCollectiveContext(CollectiveContext* col_ctx) { Status RingAlg::InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) {
DCHECK(col_ctx->dev_mgr); DCHECK(col_ctx->dev_mgr);
col_ctx_ = col_ctx; col_ctx_ = col_ctx;
col_params_ = &col_ctx->col_params; col_params_ = &col_ctx->col_params;

View File

@ -39,7 +39,8 @@ class RingAlg : public CollectiveImplementationInterface {
// Initializes members of CollectiveContext not yet initialized, i.e. device // Initializes members of CollectiveContext not yet initialized, i.e. device
// and device_locality. Also saves the CollectiveContext in this object. // and device_locality. Also saves the CollectiveContext in this object.
Status InitializeCollectiveContext(CollectiveContext* col_ctx) override; Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) override;
// No-op for ring alg. // No-op for ring alg.
Status InitializeCollectiveGroupRuntimeDetails( Status InitializeCollectiveGroupRuntimeDetails(
@ -108,7 +109,7 @@ class RingAlg : public CollectiveImplementationInterface {
const CollectiveType type_; const CollectiveType type_;
const string name_; const string name_;
CollectiveContext* col_ctx_; // Not owned std::shared_ptr<CollectiveContext> col_ctx_;
const CollectiveParams* col_params_; // Not owned const CollectiveParams* col_params_; // Not owned
StatusCallback done_; StatusCallback done_;
int group_size_; int group_size_;

View File

@ -477,10 +477,10 @@ class RingGathererTest : public ::testing::Test {
string exec_key = string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0"); strings::StrCat(col_params_.instance.instance_key, ":0:0");
RingGatherer gatherer; RingGatherer gatherer;
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(), auto col_ctx = std::make_shared<CollectiveContext>(
&ctx, &op_params, col_params_, exec_key, parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
kStepId, &input_tensor_, output_tensor_ptr); col_params_, exec_key, kStepId, &input_tensor_, output_tensor_ptr);
TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx)); TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
// Run the all-gather. // Run the all-gather.
gatherer.Run([this](Status s) { status_ = s; }); gatherer.Run([this](Status s) { status_ = s; });

View File

@ -507,10 +507,10 @@ class RingReducerTest : public ::testing::Test {
string exec_key = string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0"); strings::StrCat(col_params_.instance.instance_key, ":0:0");
RingReducer reducer; RingReducer reducer;
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(), auto col_ctx = std::make_shared<CollectiveContext>(
&ctx, &op_params, col_params_, exec_key, parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
kStepId, &tensor_, &tensor_); col_params_, exec_key, kStepId, &tensor_, &tensor_);
TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx)); TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
// Run the all-reduce. // Run the all-reduce.
reducer.Run([this](Status s) { status_ = s; }); reducer.Run([this](Status s) { status_ = s; });

View File

@ -327,7 +327,8 @@ class MockNcclReducer : public CollectiveImplementationInterface {
Status InitializeCollectiveParams(CollectiveParams*) override { Status InitializeCollectiveParams(CollectiveParams*) override {
return Status::OK(); return Status::OK();
} }
Status InitializeCollectiveContext(CollectiveContext*) override { Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext>) override {
return Status::OK(); return Status::OK();
} }
Status InitializeCollectiveGroupRuntimeDetails( Status InitializeCollectiveGroupRuntimeDetails(

View File

@ -399,7 +399,8 @@ class CollectiveImplementationInterface {
// Called from CollectiveExecutor right before calling Run(). The // Called from CollectiveExecutor right before calling Run(). The
// CollectiveContext passed in must outlive the CollectiveImplementation // CollectiveContext passed in must outlive the CollectiveImplementation
// object. // object.
virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0; virtual Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) = 0;
// Performs collective implementation specific group initialization. The // Performs collective implementation specific group initialization. The
// intention is to do group-specific initialization of runtime details for the // intention is to do group-specific initialization of runtime details for the

View File

@ -58,7 +58,8 @@ Status NcclBase::InitializeCollectiveParams(CollectiveParams* col_params) {
return Status::OK(); return Status::OK();
} }
Status NcclBase::InitializeCollectiveContext(CollectiveContext* col_ctx) { Status NcclBase::InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) {
col_ctx_ = col_ctx; col_ctx_ = col_ctx;
col_params_ = &col_ctx->col_params; col_params_ = &col_ctx->col_params;
return collective_util::InitializeDeviceAndLocality( return collective_util::InitializeDeviceAndLocality(

View File

@ -29,7 +29,8 @@ class NcclBase : public CollectiveImplementationInterface {
Status InitializeCollectiveParams(CollectiveParams* col_params) override; Status InitializeCollectiveParams(CollectiveParams* col_params) override;
// Initializes the device objects and device localities. // Initializes the device objects and device localities.
Status InitializeCollectiveContext(CollectiveContext* col_ctx) override; Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) override;
// Initialize nccl communicator key. // Initialize nccl communicator key.
Status InitializeCollectiveGroupRuntimeDetails( Status InitializeCollectiveGroupRuntimeDetails(
@ -40,7 +41,7 @@ class NcclBase : public CollectiveImplementationInterface {
const CollectiveType type_; const CollectiveType type_;
const string name_; const string name_;
CollectiveContext* col_ctx_; // Not owned std::shared_ptr<CollectiveContext> col_ctx_;
const CollectiveParams* col_params_; // Not owned const CollectiveParams* col_params_; // Not owned
}; };

View File

@ -314,11 +314,11 @@ class NcclTestBase : public ::testing::Test {
string exec_key = string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0"); strings::StrCat(col_params_.instance.instance_key, ":0:0");
NcclReducer reducer; NcclReducer reducer;
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(), auto col_ctx = std::make_shared<CollectiveContext>(
/*OpKernelContext=*/&ctx, &op_params, parent_->col_exec_, parent_->dev_mgr_.get(),
col_params_, exec_key, kStepId, /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
/*input=*/&input_, /*output=*/&input_); /*input=*/&input_, /*output=*/&input_);
TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx)); TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
Notification note; Notification note;
reducer.Run([this, &note](Status s) { reducer.Run([this, &note](Status s) {
status_ = s; status_ = s;
@ -344,12 +344,12 @@ class NcclTestBase : public ::testing::Test {
string exec_key = string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0"); strings::StrCat(col_params_.instance.instance_key, ":0:0");
NcclBroadcaster broadcaster; NcclBroadcaster broadcaster;
CollectiveContext col_ctx( auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(), parent_->col_exec_, parent_->dev_mgr_.get(),
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId, /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
/*input=*/col_params_.is_source ? &input_ : nullptr, /*input=*/col_params_.is_source ? &input_ : nullptr,
/*output=*/&input_); /*output=*/&input_);
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx)); TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
Notification note; Notification note;
broadcaster.Run([this, &note](Status s) { broadcaster.Run([this, &note](Status s) {
status_ = s; status_ = s;
@ -383,12 +383,12 @@ class NcclTestBase : public ::testing::Test {
string exec_key = string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0"); strings::StrCat(col_params_.instance.instance_key, ":0:0");
NcclGatherer gatherer; NcclGatherer gatherer;
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(), auto col_ctx = std::make_shared<CollectiveContext>(
/*OpKernelContext=*/&ctx, &op_params, parent_->col_exec_, parent_->dev_mgr_.get(),
col_params_, exec_key, kStepId, /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
/*input=*/&input_, /*input=*/&input_,
/*output=*/&output_); /*output=*/&output_);
TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx)); TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
Notification note; Notification note;
gatherer.Run([this, &note](Status s) { gatherer.Run([this, &note](Status s) {
status_ = s; status_ = s;