Share ownership of CollectiveContext with kernels.
BaseCollectiveExecutor creates a CollectiveContext and passes a pointer to each collective kernel implementation. The CollectiveContext is deleted in the done callback. However, for some kernels, such as the NCCL reducer, the CollectiveContext may still be accessed after the NCCL CUDA kernel has been enqueued on the GPU stream. This creates a race between that access and the destruction. This change converts CollectiveContext from a raw pointer to a shared pointer, so that ownership of the object is shared with the kernel. Thus, even if the done callback runs first, the kernel can still safely access the context. Resolves #41113. PiperOrigin-RevId: 321426944 Change-Id: I9f12fe403bf2cc0939006dbde38ec2985d75cfcd
This commit is contained in:
parent
ad38b201b3
commit
600b4f145b
@ -271,13 +271,12 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
|
||||
DCHECK_EQ(nullptr, col_impl);
|
||||
return;
|
||||
}
|
||||
CollectiveContext* col_ctx =
|
||||
new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
|
||||
exec_key, step_id_, input, output);
|
||||
auto col_ctx = std::make_shared<CollectiveContext>(
|
||||
this, dev_mgr_, ctx, CtxParams(ctx), col_params, exec_key, step_id_,
|
||||
input, output);
|
||||
status = col_impl->InitializeCollectiveContext(col_ctx);
|
||||
if (!status.ok()) {
|
||||
done_safe(status);
|
||||
delete col_ctx;
|
||||
delete col_impl;
|
||||
return;
|
||||
}
|
||||
@ -293,7 +292,6 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
|
||||
done_safe(s);
|
||||
delete col_ctx;
|
||||
delete col_impl;
|
||||
});
|
||||
});
|
||||
|
@ -186,7 +186,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
|
||||
}
|
||||
|
||||
Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
|
||||
CollectiveContext* col_ctx) {
|
||||
std::shared_ptr<CollectiveContext> col_ctx) {
|
||||
CHECK(col_ctx->dev_mgr);
|
||||
col_ctx_ = col_ctx;
|
||||
col_params_ = &col_ctx->col_params;
|
||||
|
@ -39,7 +39,8 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
|
||||
|
||||
// Initializes members of CollectiveContext not yet initialized, i.e. device
|
||||
// and device_locality. Also saves the CollectiveContext in this object.
|
||||
Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
|
||||
Status InitializeCollectiveContext(
|
||||
std::shared_ptr<CollectiveContext> col_ctx) override;
|
||||
|
||||
// No-op for hierarchical tree broadcaster.
|
||||
Status InitializeCollectiveGroupRuntimeDetails(
|
||||
@ -80,7 +81,7 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
|
||||
// Executes the hierarchical broadcast defined by this op.
|
||||
void RunTree();
|
||||
|
||||
CollectiveContext* col_ctx_; // Not owned
|
||||
std::shared_ptr<CollectiveContext> col_ctx_;
|
||||
const CollectiveParams* col_params_; // Not owned
|
||||
StatusCallback done_;
|
||||
Status status_;
|
||||
|
@ -670,10 +670,10 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
||||
string exec_key =
|
||||
strings::StrCat(col_params_.instance.instance_key, ":0:0");
|
||||
HierarchicalTreeBroadcaster broadcaster;
|
||||
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
|
||||
&ctx, &op_params, col_params_, exec_key,
|
||||
kStepId, input_tensor_ptr, output_tensor_ptr);
|
||||
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
|
||||
auto col_ctx = std::make_shared<CollectiveContext>(
|
||||
parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
|
||||
col_params_, exec_key, kStepId, input_tensor_ptr, output_tensor_ptr);
|
||||
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
|
||||
|
||||
// Run the broadcast.
|
||||
broadcaster.Run([this](Status s) { status_ = s; });
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/ring_alg.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
@ -240,7 +241,8 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RingAlg::InitializeCollectiveContext(CollectiveContext* col_ctx) {
|
||||
Status RingAlg::InitializeCollectiveContext(
|
||||
std::shared_ptr<CollectiveContext> col_ctx) {
|
||||
DCHECK(col_ctx->dev_mgr);
|
||||
col_ctx_ = col_ctx;
|
||||
col_params_ = &col_ctx->col_params;
|
||||
|
@ -39,7 +39,8 @@ class RingAlg : public CollectiveImplementationInterface {
|
||||
|
||||
// Initializes members of CollectiveContext not yet initialized, i.e. device
|
||||
// and device_locality. Also saves the CollectiveContext in this object.
|
||||
Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
|
||||
Status InitializeCollectiveContext(
|
||||
std::shared_ptr<CollectiveContext> col_ctx) override;
|
||||
|
||||
// No-op for ring alg.
|
||||
Status InitializeCollectiveGroupRuntimeDetails(
|
||||
@ -108,7 +109,7 @@ class RingAlg : public CollectiveImplementationInterface {
|
||||
|
||||
const CollectiveType type_;
|
||||
const string name_;
|
||||
CollectiveContext* col_ctx_; // Not owned
|
||||
std::shared_ptr<CollectiveContext> col_ctx_;
|
||||
const CollectiveParams* col_params_; // Not owned
|
||||
StatusCallback done_;
|
||||
int group_size_;
|
||||
|
@ -477,10 +477,10 @@ class RingGathererTest : public ::testing::Test {
|
||||
string exec_key =
|
||||
strings::StrCat(col_params_.instance.instance_key, ":0:0");
|
||||
RingGatherer gatherer;
|
||||
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
|
||||
&ctx, &op_params, col_params_, exec_key,
|
||||
kStepId, &input_tensor_, output_tensor_ptr);
|
||||
TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx));
|
||||
auto col_ctx = std::make_shared<CollectiveContext>(
|
||||
parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
|
||||
col_params_, exec_key, kStepId, &input_tensor_, output_tensor_ptr);
|
||||
TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
|
||||
|
||||
// Run the all-gather.
|
||||
gatherer.Run([this](Status s) { status_ = s; });
|
||||
|
@ -507,10 +507,10 @@ class RingReducerTest : public ::testing::Test {
|
||||
string exec_key =
|
||||
strings::StrCat(col_params_.instance.instance_key, ":0:0");
|
||||
RingReducer reducer;
|
||||
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
|
||||
&ctx, &op_params, col_params_, exec_key,
|
||||
kStepId, &tensor_, &tensor_);
|
||||
TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
|
||||
auto col_ctx = std::make_shared<CollectiveContext>(
|
||||
parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
|
||||
col_params_, exec_key, kStepId, &tensor_, &tensor_);
|
||||
TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
|
||||
|
||||
// Run the all-reduce.
|
||||
reducer.Run([this](Status s) { status_ = s; });
|
||||
|
@ -327,7 +327,8 @@ class MockNcclReducer : public CollectiveImplementationInterface {
|
||||
Status InitializeCollectiveParams(CollectiveParams*) override {
|
||||
return Status::OK();
|
||||
}
|
||||
Status InitializeCollectiveContext(CollectiveContext*) override {
|
||||
Status InitializeCollectiveContext(
|
||||
std::shared_ptr<CollectiveContext>) override {
|
||||
return Status::OK();
|
||||
}
|
||||
Status InitializeCollectiveGroupRuntimeDetails(
|
||||
|
@ -399,7 +399,8 @@ class CollectiveImplementationInterface {
|
||||
// Called from CollectiveExecutor right before calling Run(). The
|
||||
// CollectiveContext passed in must outlive the CollectiveImplementation
|
||||
// object.
|
||||
virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
|
||||
virtual Status InitializeCollectiveContext(
|
||||
std::shared_ptr<CollectiveContext> col_ctx) = 0;
|
||||
|
||||
// Performs collective implementation specific group initialization. The
|
||||
// intention is to do group-specific initialization of runtime details for the
|
||||
|
@ -58,7 +58,8 @@ Status NcclBase::InitializeCollectiveParams(CollectiveParams* col_params) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NcclBase::InitializeCollectiveContext(CollectiveContext* col_ctx) {
|
||||
Status NcclBase::InitializeCollectiveContext(
|
||||
std::shared_ptr<CollectiveContext> col_ctx) {
|
||||
col_ctx_ = col_ctx;
|
||||
col_params_ = &col_ctx->col_params;
|
||||
return collective_util::InitializeDeviceAndLocality(
|
||||
|
@ -29,7 +29,8 @@ class NcclBase : public CollectiveImplementationInterface {
|
||||
Status InitializeCollectiveParams(CollectiveParams* col_params) override;
|
||||
|
||||
// Initializes the device objects and device localities.
|
||||
Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
|
||||
Status InitializeCollectiveContext(
|
||||
std::shared_ptr<CollectiveContext> col_ctx) override;
|
||||
|
||||
// Initialize nccl communicator key.
|
||||
Status InitializeCollectiveGroupRuntimeDetails(
|
||||
@ -40,7 +41,7 @@ class NcclBase : public CollectiveImplementationInterface {
|
||||
|
||||
const CollectiveType type_;
|
||||
const string name_;
|
||||
CollectiveContext* col_ctx_; // Not owned
|
||||
std::shared_ptr<CollectiveContext> col_ctx_;
|
||||
const CollectiveParams* col_params_; // Not owned
|
||||
};
|
||||
|
||||
|
@ -314,11 +314,11 @@ class NcclTestBase : public ::testing::Test {
|
||||
string exec_key =
|
||||
strings::StrCat(col_params_.instance.instance_key, ":0:0");
|
||||
NcclReducer reducer;
|
||||
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
|
||||
/*OpKernelContext=*/&ctx, &op_params,
|
||||
col_params_, exec_key, kStepId,
|
||||
/*input=*/&input_, /*output=*/&input_);
|
||||
TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
|
||||
auto col_ctx = std::make_shared<CollectiveContext>(
|
||||
parent_->col_exec_, parent_->dev_mgr_.get(),
|
||||
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
|
||||
/*input=*/&input_, /*output=*/&input_);
|
||||
TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
|
||||
Notification note;
|
||||
reducer.Run([this, &note](Status s) {
|
||||
status_ = s;
|
||||
@ -344,12 +344,12 @@ class NcclTestBase : public ::testing::Test {
|
||||
string exec_key =
|
||||
strings::StrCat(col_params_.instance.instance_key, ":0:0");
|
||||
NcclBroadcaster broadcaster;
|
||||
CollectiveContext col_ctx(
|
||||
auto col_ctx = std::make_shared<CollectiveContext>(
|
||||
parent_->col_exec_, parent_->dev_mgr_.get(),
|
||||
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
|
||||
/*input=*/col_params_.is_source ? &input_ : nullptr,
|
||||
/*output=*/&input_);
|
||||
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
|
||||
TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
|
||||
Notification note;
|
||||
broadcaster.Run([this, &note](Status s) {
|
||||
status_ = s;
|
||||
@ -383,12 +383,12 @@ class NcclTestBase : public ::testing::Test {
|
||||
string exec_key =
|
||||
strings::StrCat(col_params_.instance.instance_key, ":0:0");
|
||||
NcclGatherer gatherer;
|
||||
CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
|
||||
/*OpKernelContext=*/&ctx, &op_params,
|
||||
col_params_, exec_key, kStepId,
|
||||
/*input=*/&input_,
|
||||
/*output=*/&output_);
|
||||
TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx));
|
||||
auto col_ctx = std::make_shared<CollectiveContext>(
|
||||
parent_->col_exec_, parent_->dev_mgr_.get(),
|
||||
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
|
||||
/*input=*/&input_,
|
||||
/*output=*/&output_);
|
||||
TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
|
||||
Notification note;
|
||||
gatherer.Run([this, &note](Status s) {
|
||||
status_ = s;
|
||||
|
Loading…
Reference in New Issue
Block a user