diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 1dfe2eed426..5d5100e7f2e 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -271,13 +271,12 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
     DCHECK_EQ(nullptr, col_impl);
     return;
   }
-  CollectiveContext* col_ctx =
-      new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
-                            exec_key, step_id_, input, output);
+  auto col_ctx = std::make_shared<CollectiveContext>(
+      this, dev_mgr_, ctx, CtxParams(ctx), col_params, exec_key, step_id_,
+      input, output);
   status = col_impl->InitializeCollectiveContext(col_ctx);
   if (!status.ok()) {
     done_safe(status);
-    delete col_ctx;
     delete col_impl;
     return;
   }
@@ -293,7 +292,6 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
         profiler::TraceMeLevel::kInfo);
     col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
       done_safe(s);
-      delete col_ctx;
       delete col_impl;
     });
   });
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
index d4cb79e3c05..decf8b2ccb5 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
@@ -186,7 +186,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
 }
 
 Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
-    CollectiveContext* col_ctx) {
+    std::shared_ptr<CollectiveContext> col_ctx) {
   CHECK(col_ctx->dev_mgr);
   col_ctx_ = col_ctx;
   col_params_ = &col_ctx->col_params;
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
index 38954e7dfaf..40ee3f82d48 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
@@ -39,7 +39,8 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
 
   // Initializes members of CollectiveContext not yet initialized, i.e. device
   // and device_locality.  Also saves the CollectiveContext in this object.
-  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+  Status InitializeCollectiveContext(
+      std::shared_ptr<CollectiveContext> col_ctx) override;
 
   // No-op for hierarchical tree broadcaster.
   Status InitializeCollectiveGroupRuntimeDetails(
@@ -80,7 +81,7 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
   // Executes the hierarchical broadcast defined by this op.
   void RunTree();
 
-  CollectiveContext* col_ctx_;          // Not owned
+  std::shared_ptr<CollectiveContext> col_ctx_;
   const CollectiveParams* col_params_;  // Not owned
   StatusCallback done_;
   Status status_;
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index 2006947258c..333a70adc27 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -670,10 +670,10 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
       HierarchicalTreeBroadcaster broadcaster;
-      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                                &ctx, &op_params, col_params_, exec_key,
-                                kStepId, input_tensor_ptr, output_tensor_ptr);
-      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
+      auto col_ctx = std::make_shared<CollectiveContext>(
+          parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
+          col_params_, exec_key, kStepId, input_tensor_ptr, output_tensor_ptr);
+      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
 
       // Run the broadcast.
       broadcaster.Run([this](Status s) { status_ = s; });
diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc
index 3a1a84a376d..753f6ba982e 100644
--- a/tensorflow/core/common_runtime/ring_alg.cc
+++ b/tensorflow/core/common_runtime/ring_alg.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/ring_alg.h"
 
 #include <stdlib.h>
+
 #include <atomic>
 #include <functional>
 #include <utility>
@@ -240,7 +241,8 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
   return Status::OK();
 }
 
-Status RingAlg::InitializeCollectiveContext(CollectiveContext* col_ctx) {
+Status RingAlg::InitializeCollectiveContext(
+    std::shared_ptr<CollectiveContext> col_ctx) {
   DCHECK(col_ctx->dev_mgr);
   col_ctx_ = col_ctx;
   col_params_ = &col_ctx->col_params;
diff --git a/tensorflow/core/common_runtime/ring_alg.h b/tensorflow/core/common_runtime/ring_alg.h
index c2da62c86d7..3ccb07f6d5c 100644
--- a/tensorflow/core/common_runtime/ring_alg.h
+++ b/tensorflow/core/common_runtime/ring_alg.h
@@ -39,7 +39,8 @@ class RingAlg : public CollectiveImplementationInterface {
 
   // Initializes members of CollectiveContext not yet initialized, i.e. device
   // and device_locality.  Also saves the CollectiveContext in this object.
-  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+  Status InitializeCollectiveContext(
+      std::shared_ptr<CollectiveContext> col_ctx) override;
 
   // No-op for ring alg.
   Status InitializeCollectiveGroupRuntimeDetails(
@@ -108,7 +109,7 @@ class RingAlg : public CollectiveImplementationInterface {
 
   const CollectiveType type_;
   const string name_;
-  CollectiveContext* col_ctx_;          // Not owned
+  std::shared_ptr<CollectiveContext> col_ctx_;
   const CollectiveParams* col_params_;  // Not owned
   StatusCallback done_;
   int group_size_;
diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc
index 3af4890e3d3..124965b6c6a 100644
--- a/tensorflow/core/common_runtime/ring_gatherer_test.cc
+++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc
@@ -477,10 +477,10 @@ class RingGathererTest : public ::testing::Test {
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
       RingGatherer gatherer;
-      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                                &ctx, &op_params, col_params_, exec_key,
-                                kStepId, &input_tensor_, output_tensor_ptr);
-      TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx));
+      auto col_ctx = std::make_shared<CollectiveContext>(
+          parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
+          col_params_, exec_key, kStepId, &input_tensor_, output_tensor_ptr);
+      TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
 
       // Run the all-gather.
       gatherer.Run([this](Status s) { status_ = s; });
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index 318d6e91afb..678153c3603 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -507,10 +507,10 @@ class RingReducerTest : public ::testing::Test {
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
       RingReducer reducer;
-      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                                &ctx, &op_params, col_params_, exec_key,
-                                kStepId, &tensor_, &tensor_);
-      TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
+      auto col_ctx = std::make_shared<CollectiveContext>(
+          parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
+          col_params_, exec_key, kStepId, &tensor_, &tensor_);
+      TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
 
       // Run the all-reduce.
       reducer.Run([this](Status s) { status_ = s; });
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index 13e61e55ee0..130a48e80d2 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -327,7 +327,8 @@ class MockNcclReducer : public CollectiveImplementationInterface {
   Status InitializeCollectiveParams(CollectiveParams*) override {
     return Status::OK();
   }
-  Status InitializeCollectiveContext(CollectiveContext*) override {
+  Status InitializeCollectiveContext(
+      std::shared_ptr<CollectiveContext>) override {
     return Status::OK();
   }
   Status InitializeCollectiveGroupRuntimeDetails(
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index 3726fde9809..24507b901a7 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -399,7 +399,8 @@ class CollectiveImplementationInterface {
   // Called from CollectiveExecutor right before calling Run().  The
   // CollectiveContext passed in must outlive the CollectiveImplementation
   // object.
-  virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
+  virtual Status InitializeCollectiveContext(
+      std::shared_ptr<CollectiveContext> col_ctx) = 0;
 
   // Performs collective implementation specific group initialization.  The
   // intention is to do group-specific initialization of runtime details for the
diff --git a/tensorflow/core/kernels/collective_nccl.cc b/tensorflow/core/kernels/collective_nccl.cc
index 013e06cc374..74ad24abfaa 100644
--- a/tensorflow/core/kernels/collective_nccl.cc
+++ b/tensorflow/core/kernels/collective_nccl.cc
@@ -58,7 +58,8 @@ Status NcclBase::InitializeCollectiveParams(CollectiveParams* col_params) {
   return Status::OK();
 }
 
-Status NcclBase::InitializeCollectiveContext(CollectiveContext* col_ctx) {
+Status NcclBase::InitializeCollectiveContext(
+    std::shared_ptr<CollectiveContext> col_ctx) {
   col_ctx_ = col_ctx;
   col_params_ = &col_ctx->col_params;
   return collective_util::InitializeDeviceAndLocality(
diff --git a/tensorflow/core/kernels/collective_nccl.h b/tensorflow/core/kernels/collective_nccl.h
index 5ef0d61aee5..b076272b6a5 100644
--- a/tensorflow/core/kernels/collective_nccl.h
+++ b/tensorflow/core/kernels/collective_nccl.h
@@ -29,7 +29,8 @@ class NcclBase : public CollectiveImplementationInterface {
   Status InitializeCollectiveParams(CollectiveParams* col_params) override;
 
   // Initializes the device objects and device localities.
-  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+  Status InitializeCollectiveContext(
+      std::shared_ptr<CollectiveContext> col_ctx) override;
 
   // Initialize nccl communicator key.
   Status InitializeCollectiveGroupRuntimeDetails(
@@ -40,7 +41,7 @@ class NcclBase : public CollectiveImplementationInterface {
 
   const CollectiveType type_;
   const string name_;
-  CollectiveContext* col_ctx_;          // Not owned
+  std::shared_ptr<CollectiveContext> col_ctx_;
   const CollectiveParams* col_params_;  // Not owned
 };
 
diff --git a/tensorflow/core/kernels/collective_nccl_test.cc b/tensorflow/core/kernels/collective_nccl_test.cc
index 8f3a958149b..ce4aca1cdcc 100644
--- a/tensorflow/core/kernels/collective_nccl_test.cc
+++ b/tensorflow/core/kernels/collective_nccl_test.cc
@@ -314,11 +314,11 @@ class NcclTestBase : public ::testing::Test {
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
       NcclReducer reducer;
-      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                                /*OpKernelContext=*/&ctx, &op_params,
-                                col_params_, exec_key, kStepId,
-                                /*input=*/&input_, /*output=*/&input_);
-      TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
+      auto col_ctx = std::make_shared<CollectiveContext>(
+          parent_->col_exec_, parent_->dev_mgr_.get(),
+          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
+          /*input=*/&input_, /*output=*/&input_);
+      TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
       Notification note;
       reducer.Run([this, &note](Status s) {
         status_ = s;
@@ -344,12 +344,12 @@ class NcclTestBase : public ::testing::Test {
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
       NcclBroadcaster broadcaster;
-      CollectiveContext col_ctx(
+      auto col_ctx = std::make_shared<CollectiveContext>(
           parent_->col_exec_, parent_->dev_mgr_.get(),
           /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
           /*input=*/col_params_.is_source ? &input_ : nullptr,
           /*output=*/&input_);
-      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
+      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
       Notification note;
       broadcaster.Run([this, &note](Status s) {
         status_ = s;
@@ -383,12 +383,12 @@ class NcclTestBase : public ::testing::Test {
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
       NcclGatherer gatherer;
-      CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
-                                /*OpKernelContext=*/&ctx, &op_params,
-                                col_params_, exec_key, kStepId,
-                                /*input=*/&input_,
-                                /*output=*/&output_);
-      TF_CHECK_OK(gatherer.InitializeCollectiveContext(&col_ctx));
+      auto col_ctx = std::make_shared<CollectiveContext>(
+          parent_->col_exec_, parent_->dev_mgr_.get(),
+          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
+          /*input=*/&input_,
+          /*output=*/&output_);
+      TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
       Notification note;
       gatherer.Run([this, &note](Status s) {
         status_ = s;