diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 1dfe2eed426..5d5100e7f2e 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -271,13 +271,12 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
     DCHECK_EQ(nullptr, col_impl);
     return;
   }
-  CollectiveContext* col_ctx =
-      new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
-                            exec_key, step_id_, input, output);
+  auto col_ctx = std::make_shared<CollectiveContext>(
+      this, dev_mgr_, ctx, CtxParams(ctx), col_params, exec_key, step_id_,
+      input, output);
   status = col_impl->InitializeCollectiveContext(col_ctx);
   if (!status.ok()) {
     done_safe(status);
-    delete col_ctx;
     delete col_impl;
     return;
   }
@@ -293,7 +292,6 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
         profiler::TraceMeLevel::kInfo);
     col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
       done_safe(s);
-      delete col_ctx;
       delete col_impl;
     });
   });
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
index d4cb79e3c05..decf8b2ccb5 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
@@ -186,7 +186,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
 }
 
 Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
-    CollectiveContext* col_ctx) {
+    std::shared_ptr<CollectiveContext> col_ctx) {
   CHECK(col_ctx->dev_mgr);
   col_ctx_ = col_ctx;
   col_params_ = &col_ctx->col_params;
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
index 38954e7dfaf..40ee3f82d48 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
@@ -39,7 +39,8 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
 
   // Initializes members of CollectiveContext not yet initialized, i.e. device
   // and device_locality. Also saves the CollectiveContext in this object.
-  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+  Status InitializeCollectiveContext(
+      std::shared_ptr<CollectiveContext> col_ctx) override;
 
   // No-op for hierarchical tree broadcaster.
   Status InitializeCollectiveGroupRuntimeDetails(
@@ -80,7 +81,7 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
   // Executes the hierarchical broadcast defined by this op.
   void RunTree();
 
-  CollectiveContext* col_ctx_;          // Not owned
+  std::shared_ptr<CollectiveContext> col_ctx_;
   const CollectiveParams* col_params_;  // Not owned
   StatusCallback done_;
   Status status_;
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index 2006947258c..333a70adc27 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -670,10 +670,10 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
       HierarchicalTreeBroadcaster broadcaster;
-      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                                &ctx, &op_params, col_params_, exec_key,
-                                kStepId, input_tensor_ptr, output_tensor_ptr);
-      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
+      auto col_ctx = std::make_shared<CollectiveContext>(
+          parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
+          col_params_, exec_key, kStepId, input_tensor_ptr, output_tensor_ptr);
+      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
 
       // Run the broadcast.
       broadcaster.Run([this](Status s) { status_ = s; });
diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc
index 3a1a84a376d..753f6ba982e 100644
--- a/tensorflow/core/common_runtime/ring_alg.cc
+++ b/tensorflow/core/common_runtime/ring_alg.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/ring_alg.h"
 
 #include <stdlib.h>
+
 #include <atomic>
 #include <functional>
 #include <utility>
@@ -240,7 +241,8 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
   return Status::OK();
 }
 
-Status RingAlg::InitializeCollectiveContext(CollectiveContext* col_ctx) {
+Status RingAlg::InitializeCollectiveContext(
+    std::shared_ptr<CollectiveContext> col_ctx) {
   DCHECK(col_ctx->dev_mgr);
   col_ctx_ = col_ctx;
   col_params_ = &col_ctx->col_params;
diff --git a/tensorflow/core/common_runtime/ring_alg.h b/tensorflow/core/common_runtime/ring_alg.h
index c2da62c86d7..3ccb07f6d5c 100644
--- a/tensorflow/core/common_runtime/ring_alg.h
+++ b/tensorflow/core/common_runtime/ring_alg.h
@@ -39,7 +39,8 @@ class RingAlg : public CollectiveImplementationInterface {
 
   // Initializes members of CollectiveContext not yet initialized, i.e. device
   // and device_locality. Also saves the CollectiveContext in this object.
-  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+  Status InitializeCollectiveContext(
+      std::shared_ptr<CollectiveContext> col_ctx) override;
 
   // No-op for ring alg.
   Status InitializeCollectiveGroupRuntimeDetails(
@@ -108,7 +109,7 @@ class RingAlg : public CollectiveImplementationInterface {
 
   const CollectiveType type_;
   const string name_;
-  CollectiveContext* col_ctx_;          // Not owned
+  std::shared_ptr<CollectiveContext> col_ctx_;
   const CollectiveParams* col_params_;  // Not owned
   StatusCallback done_;
   int group_size_;
diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc
index 3af4890e3d3..124965b6c6a 100644
--- a/tensorflow/core/common_runtime/ring_gatherer_test.cc
+++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc
@@ -477,10 +477,10 @@ class RingGathererTest : public ::testing::Test {
       string exec_key =
          strings::StrCat(col_params_.instance.instance_key, ":0:0");
       RingGatherer gatherer;
-      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                                &ctx, &op_params, col_params_, exec_key,
-                                kStepId, &input_tensor_, output_tensor_ptr);
-      TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx));
+      auto col_ctx = std::make_shared<CollectiveContext>(
+          parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
+          col_params_, exec_key, kStepId, &input_tensor_, output_tensor_ptr);
+      TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
 
       // Run the all-gather.
       gatherer.Run([this](Status s) { status_ = s; });
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index 318d6e91afb..678153c3603 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -507,10 +507,10 @@ class RingReducerTest : public ::testing::Test {
      string exec_key =
          strings::StrCat(col_params_.instance.instance_key, ":0:0");
      RingReducer reducer;
-     CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                               &ctx, &op_params, col_params_, exec_key,
-                               kStepId, &tensor_, &tensor_);
-     TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
+     auto col_ctx = std::make_shared<CollectiveContext>(
+         parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
+         col_params_, exec_key, kStepId, &tensor_, &tensor_);
+     TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
 
      // Run the all-reduce.
      reducer.Run([this](Status s) { status_ = s; });
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index 13e61e55ee0..130a48e80d2 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -327,7 +327,8 @@ class MockNcclReducer : public CollectiveImplementationInterface {
   Status InitializeCollectiveParams(CollectiveParams*) override {
     return Status::OK();
   }
-  Status InitializeCollectiveContext(CollectiveContext*) override {
+  Status InitializeCollectiveContext(
+      std::shared_ptr<CollectiveContext>) override {
     return Status::OK();
   }
   Status InitializeCollectiveGroupRuntimeDetails(
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index 3726fde9809..24507b901a7 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -399,7 +399,8 @@ class CollectiveImplementationInterface {
   // Called from CollectiveExecutor right before calling Run(). The
   // CollectiveContext passed in must outlive the CollectiveImplementation
   // object.
-  virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
+  virtual Status InitializeCollectiveContext(
+      std::shared_ptr<CollectiveContext> col_ctx) = 0;
 
   // Performs collective implementation specific group initialization. The
   // intention is to do group-specific initialization of runtime details for the
diff --git a/tensorflow/core/kernels/collective_nccl.cc b/tensorflow/core/kernels/collective_nccl.cc
index 013e06cc374..74ad24abfaa 100644
--- a/tensorflow/core/kernels/collective_nccl.cc
+++ b/tensorflow/core/kernels/collective_nccl.cc
@@ -58,7 +58,8 @@ Status NcclBase::InitializeCollectiveParams(CollectiveParams* col_params) {
   return Status::OK();
 }
 
-Status NcclBase::InitializeCollectiveContext(CollectiveContext* col_ctx) {
+Status NcclBase::InitializeCollectiveContext(
+    std::shared_ptr<CollectiveContext> col_ctx) {
   col_ctx_ = col_ctx;
   col_params_ = &col_ctx->col_params;
   return collective_util::InitializeDeviceAndLocality(
diff --git a/tensorflow/core/kernels/collective_nccl.h b/tensorflow/core/kernels/collective_nccl.h
index 5ef0d61aee5..b076272b6a5 100644
--- a/tensorflow/core/kernels/collective_nccl.h
+++ b/tensorflow/core/kernels/collective_nccl.h
@@ -29,7 +29,8 @@ class NcclBase : public CollectiveImplementationInterface {
   Status InitializeCollectiveParams(CollectiveParams* col_params) override;
 
   // Initializes the device objects and device localities.
-  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+  Status InitializeCollectiveContext(
+      std::shared_ptr<CollectiveContext> col_ctx) override;
 
   // Initialize nccl communicator key.
   Status InitializeCollectiveGroupRuntimeDetails(
@@ -40,7 +41,7 @@ class NcclBase : public CollectiveImplementationInterface {
 
   const CollectiveType type_;
   const string name_;
-  CollectiveContext* col_ctx_;          // Not owned
+  std::shared_ptr<CollectiveContext> col_ctx_;
   const CollectiveParams* col_params_;  // Not owned
 };
 
diff --git a/tensorflow/core/kernels/collective_nccl_test.cc b/tensorflow/core/kernels/collective_nccl_test.cc
index 8f3a958149b..ce4aca1cdcc 100644
--- a/tensorflow/core/kernels/collective_nccl_test.cc
+++ b/tensorflow/core/kernels/collective_nccl_test.cc
@@ -314,11 +314,11 @@ class NcclTestBase : public ::testing::Test {
       string exec_key =
          strings::StrCat(col_params_.instance.instance_key, ":0:0");
       NcclReducer reducer;
-      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                                /*OpKernelContext=*/&ctx, &op_params,
-                                col_params_, exec_key, kStepId,
-                                /*input=*/&input_, /*output=*/&input_);
-      TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
+      auto col_ctx = std::make_shared<CollectiveContext>(
+          parent_->col_exec_, parent_->dev_mgr_.get(),
+          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
+          /*input=*/&input_, /*output=*/&input_);
+      TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
       Notification note;
       reducer.Run([this, &note](Status s) {
         status_ = s;
@@ -344,12 +344,12 @@ class NcclTestBase : public ::testing::Test {
       string exec_key =
          strings::StrCat(col_params_.instance.instance_key, ":0:0");
       NcclBroadcaster broadcaster;
-      CollectiveContext col_ctx(
+      auto col_ctx = std::make_shared<CollectiveContext>(
           parent_->col_exec_, parent_->dev_mgr_.get(),
           /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
           /*input=*/col_params_.is_source ? &input_ : nullptr,
           /*output=*/&input_);
-      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
+      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
       Notification note;
       broadcaster.Run([this, &note](Status s) {
         status_ = s;
@@ -383,12 +383,12 @@ class NcclTestBase : public ::testing::Test {
      string exec_key =
          strings::StrCat(col_params_.instance.instance_key, ":0:0");
      NcclGatherer gatherer;
-     CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                               /*OpKernelContext=*/&ctx, &op_params,
-                               col_params_, exec_key, kStepId,
-                               /*input=*/&input_,
-                               /*output=*/&output_);
-     TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx));
+     auto col_ctx = std::make_shared<CollectiveContext>(
+         parent_->col_exec_, parent_->dev_mgr_.get(),
+         /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
+         /*input=*/&input_,
+         /*output=*/&output_);
+     TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
      Notification note;
      gatherer.Run([this, &note](Status s) {
        status_ = s;
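
Note (not part of the patch): the sketch below is a minimal standalone illustration, not TensorFlow code, of the ownership pattern this diff adopts. The names Context, Collective, Initialize, Run, and ExecuteAsync are made up for the example. The point it demonstrates is that the context is created with std::make_shared, handed to the implementation, and captured by the asynchronous completion callback, so the explicit "delete col_ctx" calls become unnecessary while the implementation object itself is still deleted manually, as in the patch.

#include <chrono>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>

struct Context {
  std::string exec_key;
};

class Collective {
 public:
  // Keep a shared reference, mirroring the std::shared_ptr<CollectiveContext>
  // member that replaces the raw "not owned" pointer in the diff above.
  void Initialize(std::shared_ptr<Context> ctx) { ctx_ = std::move(ctx); }

  // Completion may happen on another thread; the shared_ptr held by this
  // object (and by the caller's callback) keeps the context alive until the
  // callback has run, so no manual delete of the context is needed.
  void Run(std::function<void()> done) {
    std::thread([this, done = std::move(done)] {
      std::cout << "running collective for " << ctx_->exec_key << "\n";
      done();
    }).detach();
  }

 private:
  std::shared_ptr<Context> ctx_;
};

void ExecuteAsync(std::function<void()> done) {
  auto ctx = std::make_shared<Context>(Context{"instance:0:0"});
  auto* impl = new Collective;
  impl->Initialize(ctx);
  // Capturing ctx in the callback replaces the explicit "delete col_ctx"
  // calls removed in base_collective_executor.cc; the implementation object
  // is still owned manually and deleted in the callback, as in the patch.
  impl->Run([impl, ctx, done = std::move(done)] {
    done();
    delete impl;
  });
}

int main() {
  ExecuteAsync([] { std::cout << "done\n"; });
  // Crude wait for the detached worker; good enough for this sketch.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  return 0;
}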