From 18eaf4e8f1769d2bb05ed354a8b6198e49aadfcc Mon Sep 17 00:00:00 2001 From: Ayush Dubey <ayushd@google.com> Date: Thu, 4 Feb 2021 09:40:21 -0800 Subject: [PATCH] Ensure that `CollectiveParams` outlives all references to it. Before this change, it was possible to access a `const CollectiveParams&` after it was destroyed. For example, the call to `UnblockDependencies` in `NcclCommunicator::Enqueue` raced with the done_callback of the collective participant. This change makes `CollectiveParams` a refcounted object, and holds references everywhere it may be accessed. PiperOrigin-RevId: 355646163 Change-Id: I7fd164afe8c1c9aa1c3b77a988930a0624977c7c --- .../base_collective_executor.cc | 16 +- .../common_runtime/base_collective_executor.h | 2 +- .../collective_param_resolver_local.cc | 16 +- .../collective_param_resolver_local.h | 6 +- .../collective_param_resolver_local_test.cc | 129 +++++---- .../hierarchical_tree_broadcaster.cc | 2 +- .../hierarchical_tree_broadcaster_test.cc | 229 ++++++++-------- tensorflow/core/common_runtime/permuter.cc | 2 +- .../core/common_runtime/permuter_test.cc | 68 ++--- tensorflow/core/common_runtime/ring_alg.cc | 2 +- .../core/common_runtime/ring_gatherer_test.cc | 141 +++++----- .../core/common_runtime/ring_reducer_test.cc | 175 +++++++------ .../collective_param_resolver_distributed.cc | 15 +- ...lective_param_resolver_distributed_test.cc | 60 +++-- tensorflow/core/framework/collective.cc | 4 +- tensorflow/core/framework/collective.h | 8 +- tensorflow/core/kernels/collective_nccl.cc | 2 +- .../core/kernels/collective_nccl_reducer.cc | 3 + tensorflow/core/kernels/collective_ops.cc | 247 +++++++++--------- .../core/nccl/collective_communicator.cc | 44 ++-- 20 files changed, 627 insertions(+), 544 deletions(-) diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index a2cfce1111c..c365cddae2d 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -264,7 +264,7 @@ Status BaseCollectiveExecutor::GetStatus(const Status& s) { } void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, - const CollectiveParams& col_params, + const CollectiveParams* col_params, const string& exec_key, StatusCallback done) { // See CompleteParamsAsync() how done() and the timeout callback interacts. @@ -281,7 +281,7 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, } }; auto timeout_microseconds = static_cast<int64>( - col_params.instance.impl_details.timeout_seconds * 1'000'000); + col_params->instance.impl_details.timeout_seconds * 1'000'000); if (timeout_microseconds > 0) { // TODO(xldrx): Share the timeout watchdog thread among collectives. SchedNonBlockingClosureAfter( @@ -297,15 +297,15 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, } Tensor* output = ctx->mutable_output(0); - const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE || - col_params.instance.type == GATHER_COLLECTIVE || - col_params.instance.type == PERMUTE_COLLECTIVE || - (col_params.instance.type == BROADCAST_COLLECTIVE && - col_params.is_source)) + const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE || + col_params->instance.type == GATHER_COLLECTIVE || + col_params->instance.type == PERMUTE_COLLECTIVE || + (col_params->instance.type == BROADCAST_COLLECTIVE && + col_params->is_source)) ? 
&ctx->input(0) : nullptr; CollectiveImplementationInterface* col_impl = nullptr; - Status status = CreateCollective(col_params, &col_impl); + Status status = CreateCollective(*col_params, &col_impl); if (!status.ok()) { done_safe(status); DCHECK_EQ(nullptr, col_impl); diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h index 142c825df55..8dd0a55ef18 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.h +++ b/tensorflow/core/common_runtime/base_collective_executor.h @@ -110,7 +110,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor { void StartAbort(const Status& s) override TF_LOCKS_EXCLUDED(status_mu_); - void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params, + void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams* col_params, const string& exec_key, StatusCallback done) override; void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index c42b6a61f57..9bc3b2e1b40 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -513,14 +513,14 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device, void CollectiveParamResolverLocal::InitInstanceSharedParams( const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) { - ir->shared.instance = cp->instance; - ir->shared.default_rank = -1; + ir->shared->instance = cp->instance; + ir->shared->default_rank = -1; // Set is_local and task_names in *shared prior to invoking // GetDeviceAttributesAsync. In a distributed context this function can be // called by a derived class, some of the devices may be non-local and // GetDeviceAttributesAsync will use those fields to launch RPCs. - CompleteTaskIsLocal(task_name_, &ir->shared); + CompleteTaskIsLocal(task_name_, ir->shared); } // NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks @@ -662,11 +662,11 @@ void CollectiveParamResolverLocal::CompleteInstanceLocal( if (!created_irec) { // Check that the preexisting IRec is consistent with the params passed into // this invocation. - if (ir->shared.instance.type != cp->instance.type || - ir->shared.instance.data_type != cp->instance.data_type) { + if (ir->shared->instance.type != cp->instance.type || + ir->shared->instance.data_type != cp->instance.data_type) { done(errors::Internal("Collective instance ", cp->instance.instance_key, - " expected type ", ir->shared.instance.type, - " and data_type ", ir->shared.instance.data_type, + " expected type ", ir->shared->instance.type, + " and data_type ", ir->shared->instance.data_type, " but got type ", cp->instance.type, " and data_type ", cp->instance.data_type)); return; @@ -686,7 +686,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec( status = ir->status; if (status.ok()) { // custom operator= does a deep copy. 
- cp->instance = ir->shared.instance; + cp->instance = ir->shared->instance; } } if (!status.ok()) { diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 63a3bf2e063..5a7cd54a0de 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -98,7 +98,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { struct InstanceRec { mutex mu; // Values to be shared by all instances, constant after initialization. - CollectiveParams shared; + CollectiveParams* shared; // If an error occurs during initialization this structure stays in the // table with a non-OK status. Purging the table and restarting needs to be // done at a higher level. @@ -113,7 +113,9 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { std::vector<bool> known TF_GUARDED_BY(mu); std::vector<IRConsumer> known_waiters TF_GUARDED_BY(mu); - InstanceRec() : source_rank(-1), known_count(0) {} + InstanceRec() + : shared(new CollectiveParams()), source_rank(-1), known_count(0) {} + ~InstanceRec() { shared->Unref(); } }; // Find the InstanceRec with the same instance_key as cp. If it doesn't diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc index e1ac46f2e53..611d6bbff50 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc @@ -161,11 +161,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { } TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { - CollectiveParams cps[NUM_DEVS]; + CollectiveParams* cps[NUM_DEVS]; Status statuses[NUM_DEVS]; Notification note[NUM_DEVS]; for (int i = 0; i < NUM_DEVS; ++i) { - CollectiveParams* cp = &cps[i]; + cps[i] = new CollectiveParams(); + CollectiveParams* cp = cps[i]; cp->group.group_key = 1; cp->group.group_size = 3; cp->group.device_type = DeviceType("CPU"); @@ -192,17 +193,18 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { } for (int i = 0; i < NUM_DEVS; ++i) { TF_ASSERT_OK(statuses[i]); - ASSERT_EQ(cps[i].group.device_names.size(), 3); + ASSERT_EQ(cps[i]->group.device_names.size(), 3); for (int j = 0; j < NUM_DEVS; ++j) { EXPECT_EQ( strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j), - cps[i].group.device_names[j]); - EXPECT_TRUE(cps[i].task.is_local[j]); + cps[i]->group.device_names[j]); + EXPECT_TRUE(cps[i]->task.is_local[j]); } - EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0); - EXPECT_FALSE(cps[i].is_source); - EXPECT_EQ(cps[i].default_rank, i); - EXPECT_TRUE(cps[i].group.same_num_devices_per_task); + EXPECT_EQ(cps[i]->instance.impl_details.subdiv_source_rank.size(), 0); + EXPECT_FALSE(cps[i]->is_source); + EXPECT_EQ(cps[i]->default_rank, i); + EXPECT_TRUE(cps[i]->group.same_num_devices_per_task); + cps[i]->Unref(); } } @@ -223,11 +225,12 @@ void InitializeCollectiveParamsForBroadcast(int instance_key, int device_idx, TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) { constexpr int kInstanceKey = 5; - CollectiveParams cps[NUM_DEVS]; + CollectiveParams* cps[NUM_DEVS]; Status statuses[NUM_DEVS]; Notification note[NUM_DEVS]; for (int i = 0; i < NUM_DEVS; ++i) { - CollectiveParams* cp = &cps[i]; + cps[i] = new CollectiveParams(); + 
CollectiveParams* cp = cps[i];
     InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
       string device =
@@ -245,16 +248,17 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
   }
   for (int i = 0; i < NUM_DEVS; ++i) {
     TF_ASSERT_OK(statuses[i]);
-    ASSERT_EQ(cps[i].group.device_names.size(), 3);
+    ASSERT_EQ(cps[i]->group.device_names.size(), 3);
     for (int j = 0; j < NUM_DEVS; ++j) {
       EXPECT_EQ(
           strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
-          cps[i].group.device_names[j]);
-      EXPECT_TRUE(cps[i].task.is_local[j]);
+          cps[i]->group.device_names[j]);
+      EXPECT_TRUE(cps[i]->task.is_local[j]);
     }
-    EXPECT_EQ(cps[i].is_source, (i == 1));
-    EXPECT_EQ(cps[i].default_rank, i);
-    EXPECT_TRUE(cps[i].group.same_num_devices_per_task);
+    EXPECT_EQ(cps[i]->is_source, (i == 1));
+    EXPECT_EQ(cps[i]->default_rank, i);
+    EXPECT_TRUE(cps[i]->group.same_num_devices_per_task);
+    cps[i]->Unref();
   }
 }

@@ -263,11 +267,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
 // get an internal error from param resolution.
 TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
   constexpr int kInstanceKey = 8;
-  CollectiveParams cps[NUM_DEVS];
+  CollectiveParams* cps[NUM_DEVS];
   Status statuses[NUM_DEVS];
   Notification note[NUM_DEVS];
   for (int i = 0; i < NUM_DEVS; ++i) {
-    CollectiveParams* cp = &cps[i];
+    cps[i] = new CollectiveParams();
+    CollectiveParams* cp = cps[i];
     InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
     Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
       string device =
@@ -291,27 +296,28 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
             " found no source for broadcast. This could mean that there"
             " were group_size=",
             NUM_DEVS, " BcastRecvs but no BcastSend."));
+    cps[i]->Unref();
   }
 }

-CollectiveParams MakeCollectiveParams(int group_key, int instance_key,
-                                      bool is_source) {
-  CollectiveParams cp;
-  cp.group.group_key = group_key;
-  cp.group.group_size = NUM_DEVS;
-  cp.group.device_type = DeviceType("CPU");
-  cp.group.num_tasks = 1;
-  cp.instance.instance_key = instance_key;
+CollectiveParams* MakeCollectiveParams(int group_key, int instance_key,
+                                       bool is_source) {
+  auto* cp = new CollectiveParams();
+  cp->group.group_key = group_key;
+  cp->group.group_size = NUM_DEVS;
+  cp->group.device_type = DeviceType("CPU");
+  cp->group.num_tasks = 1;
+  cp->instance.instance_key = instance_key;
   // CompleteInstanceLocal only waits for the group for broadcasts.
   // Testing with broadcasts yields better coverage.
- cp.instance.type = BROADCAST_COLLECTIVE; - cp.is_source = is_source; + cp->instance.type = BROADCAST_COLLECTIVE; + cp->is_source = is_source; return cp; } TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) { CancellationManager cancel_mgr; - std::vector<CollectiveParams> cp(NUM_DEVS - 1); + std::vector<CollectiveParams*> cp(NUM_DEVS - 1); BlockingCounter start(NUM_DEVS - 1); BlockingCounter done(NUM_DEVS - 1); for (int i = 0; i < NUM_DEVS - 1; ++i) { @@ -320,11 +326,12 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], - &cancel_mgr, [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr, + [&done, cp = cp[i]](const Status& s) { EXPECT_EQ(s.code(), error::ABORTED); EXPECT_EQ(s.error_message(), "__aborted__"); done.DecrementCount(); + cp->Unref(); }); start.DecrementCount(); }); @@ -336,7 +343,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) { TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { CancellationManager cancel_mgr; - std::vector<CollectiveParams> cp(NUM_DEVS); + std::vector<CollectiveParams*> cp(NUM_DEVS); int group_key = 100; int instance_key = 100; // First do a normal CompleteParamsAsync to complete the group; @@ -349,10 +356,12 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(group_key, instance_key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], - &cancel_mgr, [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], + &cancel_mgr, + [&done, cp = cp[i]](const Status& s) { EXPECT_EQ(s.code(), error::OK); done.DecrementCount(); + cp->Unref(); }); }); } @@ -361,21 +370,21 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { BlockingCounter start(NUM_DEVS - 1); BlockingCounter done(NUM_DEVS - 1); for (int i = 0; i < NUM_DEVS - 1; ++i) { - Env::Default()->SchedClosure( - [this, group_key, instance_key, i, &cancel_mgr, &cp, &start, &done] { - string device = - strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); - cp[i] = MakeCollectiveParams(group_key, instance_key + 1, - /*is_source*/ i == 0); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], - &cancel_mgr, [&done](const Status& s) { - EXPECT_EQ(s.code(), error::ABORTED); - EXPECT_EQ(s.error_message(), - "__aborted__"); - done.DecrementCount(); - }); - start.DecrementCount(); - }); + Env::Default()->SchedClosure([this, group_key, instance_key, i, &cancel_mgr, + &cp, &start, &done] { + string device = + strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); + cp[i] = MakeCollectiveParams(group_key, instance_key + 1, + /*is_source*/ i == 0); + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr, + [&done, cp = cp[i]](const Status& s) { + EXPECT_EQ(s.code(), error::ABORTED); + EXPECT_EQ(s.error_message(), "__aborted__"); + done.DecrementCount(); + cp->Unref(); + }); + start.DecrementCount(); + }); } start.Wait(); prl_->StartAbort(Status(error::ABORTED, "__aborted__")); @@ -388,7 +397,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { int instance_key = 100; // First do a normal CompleteParamsAsync to complete the group; { - 
std::vector<CollectiveParams> cp(NUM_DEVS); + std::vector<CollectiveParams*> cp(NUM_DEVS); BlockingCounter done(NUM_DEVS); for (int i = 0; i < NUM_DEVS; ++i) { Env::Default()->SchedClosure([this, group_key, instance_key, i, @@ -397,10 +406,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); cp[i] = MakeCollectiveParams(group_key, instance_key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i], - &cancel_mgr, [&done](const Status& s) { + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], + &cancel_mgr, + [&done, cp = cp[i]](const Status& s) { EXPECT_EQ(s.code(), error::OK); done.DecrementCount(); + cp->Unref(); }); }); } @@ -411,9 +422,10 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { auto complete_params = [this, &cancel_mgr](int group_key, int instance_key) { string device = "/job:localhost/replica:0/task:0/device:CPU:0"; Notification done; - auto cp = MakeCollectiveParams(group_key, instance_key, - /*is_source*/ true); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, &cancel_mgr, + auto* cp = MakeCollectiveParams(group_key, instance_key, + /*is_source*/ true); + core::ScopedUnref unref(cp); + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, &cancel_mgr, [&done](const Status& s) { EXPECT_EQ(s.code(), error::ABORTED); EXPECT_EQ(s.error_message(), "__aborted__"); @@ -449,16 +461,17 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) { while (true) { Status status; Notification n; - auto cp = + auto* cp = MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key, /*is_source*/ i == 0); - prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, &cancel_mgr, [&status, &n](const Status& s) { status = s; n.Notify(); }); n.WaitForNotification(); + cp->Unref(); // The status should be either OK or the aborted status. 
if (!status.ok()) { EXPECT_EQ(status.code(), error::ABORTED); diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc index e78fbef13de..ebe568d6bac 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc @@ -188,7 +188,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveContext( std::shared_ptr<CollectiveContext> col_ctx) { CHECK(col_ctx->dev_mgr); col_ctx_ = col_ctx; - col_params_ = &col_ctx->col_params; + col_params_ = col_ctx->col_params; return collective_util::InitializeDeviceAndLocality( col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, &col_ctx->device_locality); diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc index 97a1d0b46ce..378dc459da1 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc @@ -56,23 +56,24 @@ class TrivialTest : public ::testing::Test { // R = tested rank // RF = receive-from rank // ST = send_to rank vector -#define DEF_TL_TEST(D, S, R, RF, ST) \ - TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \ - CollectiveParams cp; \ - cp.group.group_size = D; \ - cp.instance.impl_details.subdiv_source_rank = {S}; \ - cp.instance.impl_details.subdiv_permutations.push_back( \ - std::vector<int>(D, 0)); \ - cp.subdiv_rank = {R}; \ - cp.is_source = (S == R); \ - EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(cp, 0)); \ - std::vector<int> expected = ST; \ - std::vector<int> send_to; \ - HierarchicalTreeBroadcaster::TreeSendTo(cp, 0, &send_to); \ - ASSERT_EQ(expected.size(), send_to.size()); \ - for (int i = 0; i < expected.size(); ++i) { \ - EXPECT_EQ(expected[i], send_to[i]); \ - } \ +#define DEF_TL_TEST(D, S, R, RF, ST) \ + TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \ + auto* cp = new CollectiveParams(); \ + core::ScopedUnref unref(cp); \ + cp->group.group_size = D; \ + cp->instance.impl_details.subdiv_source_rank = {S}; \ + cp->instance.impl_details.subdiv_permutations.push_back( \ + std::vector<int>(D, 0)); \ + cp->subdiv_rank = {R}; \ + cp->is_source = (S == R); \ + EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(*cp, 0)); \ + std::vector<int> expected = ST; \ + std::vector<int> send_to; \ + HierarchicalTreeBroadcaster::TreeSendTo(*cp, 0, &send_to); \ + ASSERT_EQ(expected.size(), send_to.size()); \ + for (int i = 0; i < expected.size(); ++i) { \ + EXPECT_EQ(expected[i], send_to[i]); \ + } \ } #define V(...) 
std::vector<int>({__VA_ARGS__}) @@ -196,12 +197,14 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { class HierarchicalTreeBroadcasterTest : public ::testing::Test { protected: - HierarchicalTreeBroadcasterTest() : device_type_(DEVICE_CPU) {} + HierarchicalTreeBroadcasterTest() + : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {} ~HierarchicalTreeBroadcasterTest() override { stop_ = true; for (auto i : instances_) delete i; if (col_exec_) col_exec_->Unref(); + if (col_params_) col_params_->Unref(); } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -262,30 +265,31 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get(), work_queue_); - col_params_.name = "test_collective"; - col_params_.instance.data_type = dtype; + col_params_ = new CollectiveParams(); + col_params_->name = "test_collective"; + col_params_->instance.data_type = dtype; static const int kGroupKey = 6; - col_params_.group.group_key = kGroupKey; + col_params_->group.group_key = kGroupKey; static const int kInstanceKey = 18; - col_params_.instance.instance_key = kInstanceKey; - col_params_.group.device_type = device_type; - col_params_.group.group_size = num_workers * num_devices_per_worker; - col_params_.instance.impl_details.subdiv_offsets.clear(); - col_params_.instance.type = BROADCAST_COLLECTIVE; + col_params_->instance.instance_key = kInstanceKey; + col_params_->group.device_type = device_type; + col_params_->group.group_size = num_workers * num_devices_per_worker; + col_params_->instance.impl_details.subdiv_offsets.clear(); + col_params_->instance.type = BROADCAST_COLLECTIVE; int num_subdivs = num_workers + (num_workers > 1 ? 1 : 0); VLOG(2) << "#subdiv=" << num_subdivs; - col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); - col_params_.subdiv_rank.resize(num_subdivs); + col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs); + col_params_->subdiv_rank.resize(num_subdivs); // Inter-machine broadcast. int subdiv_i = 0; if (num_workers > 1) { - col_params_.instance.impl_details.subdiv_permutations[subdiv_i].resize( + col_params_->instance.impl_details.subdiv_permutations[subdiv_i].resize( total_num_devices, -1); for (int i = 0, rank = 0; i < total_num_devices; i++) { if (i % num_devices_per_worker == 0) { - col_params_.instance.impl_details + col_params_->instance.impl_details .subdiv_permutations[subdiv_i][rank] = i; rank++; } @@ -293,7 +297,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { if (VLOG_IS_ON(2)) { string sp_buf; for (int p : - col_params_.instance.impl_details.subdiv_permutations[subdiv_i]) + col_params_->instance.impl_details.subdiv_permutations[subdiv_i]) strings::StrAppend(&sp_buf, p, ", "); VLOG(2) << "subdiv_i=" << subdiv_i << " perm=" << sp_buf; } @@ -301,22 +305,22 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { } // Intra-machine broadcast. for (int i = 0; subdiv_i < num_subdivs; i++, subdiv_i++) { - col_params_.instance.impl_details.subdiv_permutations[subdiv_i].resize( + col_params_->instance.impl_details.subdiv_permutations[subdiv_i].resize( total_num_devices, -1); int perm_i_base = i * num_devices_per_worker; VLOG(2) << "subdiv_i=" << subdiv_i << " i=" << i << " perm_i_base=" << perm_i_base << " subdiv_perms.size=" - << col_params_.instance.impl_details.subdiv_permutations.size(); + << col_params_->instance.impl_details.subdiv_permutations.size(); // subdiv for worker i. 
for (int j = perm_i_base, rank = 0; j < perm_i_base + num_devices_per_worker; j++, rank++) { - col_params_.instance.impl_details.subdiv_permutations[subdiv_i][rank] = + col_params_->instance.impl_details.subdiv_permutations[subdiv_i][rank] = j; } if (VLOG_IS_ON(2)) { string sp_buf; for (int p : - col_params_.instance.impl_details.subdiv_permutations[subdiv_i]) + col_params_->instance.impl_details.subdiv_permutations[subdiv_i]) strings::StrAppend(&sp_buf, p, ", "); VLOG(2) << "subdiv_i=" << subdiv_i << " perm=" << sp_buf; } @@ -333,16 +337,16 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { dev_name = strings::StrCat(task_name, "/device:CPU:", di); } VLOG(2) << "dev=" << dev_name; - col_params_.group.device_names.push_back(dev_name); - col_params_.group.task_names.push_back(task_name); - col_params_.task.is_local.push_back(true); + col_params_->group.device_names.push_back(dev_name); + col_params_->group.task_names.push_back(task_name); + col_params_->task.is_local.push_back(true); } } for (int wi = 0; wi < num_workers; wi++) { for (int di = 0; di < num_devices_per_worker; di++) { int default_rank = wi * num_devices_per_worker + di; instances_.push_back(new DeviceInstance( - default_rank, col_params_.group.device_names[default_rank], + default_rank, col_params_->group.device_names[default_rank], device_type, this)); } } @@ -435,7 +439,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { // Copy the expected value from the broadcast source tensor std::vector<T> expected(tensor_len, 0.0); - const CollectiveParams& cp = instances_[0]->col_params_; + const CollectiveParams& cp = *instances_[0]->col_params_; int broadcast_dev_id = cp.instance.impl_details.subdiv_permutations [0][cp.instance.impl_details.subdiv_source_rank[0]]; @@ -558,27 +562,29 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { : parent_(parent), dev_name_(dev_name), device_type_(device_type), - rank_(rank) { + rank_(rank), + col_params_(new CollectiveParams()) { TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_)); - col_params_.name = parent_->col_params_.name; - col_params_.instance.data_type = parent_->col_params_.instance.data_type; - col_params_.group = parent_->col_params_.group; - col_params_.instance.instance_key = - parent_->col_params_.instance.instance_key; - col_params_.task.is_local = parent_->col_params_.task.is_local; - col_params_.instance.impl_details.subdiv_permutations = - parent_->col_params_.instance.impl_details.subdiv_permutations; - col_params_.subdiv_rank = parent_->col_params_.subdiv_rank; + col_params_->name = parent_->col_params_->name; + col_params_->instance.data_type = + parent_->col_params_->instance.data_type; + col_params_->group = parent_->col_params_->group; + col_params_->instance.instance_key = + parent_->col_params_->instance.instance_key; + col_params_->task.is_local = parent_->col_params_->task.is_local; + col_params_->instance.impl_details.subdiv_permutations = + parent_->col_params_->instance.impl_details.subdiv_permutations; + col_params_->subdiv_rank = parent_->col_params_->subdiv_rank; - int group_size = col_params_.group.group_size; - CHECK_EQ(group_size, col_params_.group.device_names.size()); + int group_size = col_params_->group.group_size; + CHECK_EQ(group_size, col_params_->group.device_names.size()); // Default rank is order in device_names. 
- col_params_.default_rank = rank; + col_params_->default_rank = rank; - auto& impl = col_params_.instance.impl_details; + auto& impl = col_params_->instance.impl_details; size_t num_subdivs = impl.subdiv_permutations.size(); impl.subdiv_source_rank.resize(num_subdivs, 0); - col_params_.subdiv_rank.resize(num_subdivs); + col_params_->subdiv_rank.resize(num_subdivs); for (size_t si = 0; si < num_subdivs; si++) { int perm_rank = -1; for (int i = 0; i < group_size; i++) { @@ -587,18 +593,20 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { break; } } - col_params_.subdiv_rank[si] = perm_rank; + col_params_->subdiv_rank[si] = perm_rank; } string rank_buf; - for (int r : col_params_.subdiv_rank) { + for (int r : col_params_->subdiv_rank) { strings::StrAppend(&rank_buf, r, ", "); } VLOG(1) << "default=" << rank << " subdiv_ranks=" << rank_buf; - col_params_.is_source = - col_params_.subdiv_rank[0] == impl.subdiv_source_rank[0]; + col_params_->is_source = + col_params_->subdiv_rank[0] == impl.subdiv_source_rank[0]; } + ~DeviceInstance() { col_params_->Unref(); } + void InitTensor(DataType dtype, const TensorShape& shape, const InitFunc& f) { tensor_ = @@ -641,22 +649,22 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { op_params.op_device_context = dev_ctx; int forward_from[] = {OpKernelContext::Params::kNeverForward}; if (forward_input) forward_from[0] = 0; - if (col_params_.is_source) { + if (col_params_->is_source) { op_params.forward_from_array = &forward_from[0]; } AllocatorAttributes generic_alloc_attr; op_params.output_attr_array = &generic_alloc_attr; std::unique_ptr<OpKernel> op = - col_params_.is_source - ? parent_->GetCollectiveBcastSend(col_params_, &tensor_, + col_params_->is_source + ? parent_->GetCollectiveBcastSend(*col_params_, &tensor_, DEVICE_CPU, device_) - : parent_->GetCollectiveBcastRecv(col_params_, tensor_.shape(), + : parent_->GetCollectiveBcastRecv(*col_params_, tensor_.shape(), DEVICE_CPU, device_); op_params.op_kernel = op.get(); OpKernelContext ctx(&op_params, 1); Tensor* output_tensor_ptr = nullptr; - if (col_params_.is_source) { + if (col_params_->is_source) { TF_CHECK_OK(ctx.forward_input_or_allocate_output( {0}, 0, tensor_.shape(), &output_tensor_ptr)); } else { @@ -665,11 +673,11 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { } CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0)); const Tensor* input_tensor_ptr = - col_params_.is_source ? &tensor_ : nullptr; + col_params_->is_source ? &tensor_ : nullptr; // Prepare a Broadcaster instance. 
string exec_key = - strings::StrCat(col_params_.instance.instance_key, ":0:0"); + strings::StrCat(col_params_->instance.instance_key, ":0:0"); HierarchicalTreeBroadcaster* broadcaster = new HierarchicalTreeBroadcaster; core::ScopedUnref unref(broadcaster); @@ -694,7 +702,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { int rank_; Tensor tensor_; Device* device_; - CollectiveParams col_params_; + CollectiveParams* col_params_; Status status_; }; // class DeviceInstance @@ -708,7 +716,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { std::unique_ptr<DeviceResolverLocal> dev_resolver_; std::shared_ptr<UnboundedWorkQueue> work_queue_; std::vector<DeviceInstance*> instances_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_; std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_; std::unique_ptr<string> gpu_ring_order_; @@ -720,33 +728,35 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { }; TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams1Task8GPU) { - CollectiveParams cp; - PrepColParamsForSubdivPermsTest(&cp, 1, 8); + auto* cp = new CollectiveParams(); + core::ScopedUnref unref(cp); + PrepColParamsForSubdivPermsTest(cp, 1, 8); // source 0 device 0 - cp.source_rank = 0; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0}); + cp->source_rank = 0; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0}); // source 2 device 2 - cp.source_rank = 2; - cp.default_rank = 2; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2}); + cp->source_rank = 2; + cp->default_rank = 2; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2}); // source 2 device 0 - cp.source_rank = 2; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2}); + cp->source_rank = 2; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2}); } TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) { - CollectiveParams cp; - PrepColParamsForSubdivPermsTest(&cp, 4, 8); + auto* cp = new CollectiveParams(); + core::ScopedUnref unref(cp); + PrepColParamsForSubdivPermsTest(cp, 4, 8); // source 0 device 0 - cp.source_rank = 0; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, + cp->source_rank = 0; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{0, 8, 16, 24}, {0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10, 11, 12, 13, 14, 15}, @@ -755,9 +765,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) { {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0}); // source 2 device 0 - cp.source_rank = 2; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, + cp->source_rank = 2; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{2, 8, 16, 24}, {0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10, 11, 12, 13, 14, 15}, @@ -766,9 +776,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) { {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0}); // source 9 device 9 - cp.source_rank = 9; - cp.default_rank = 9; - RunSubdivPermsTest(&cp, + cp->source_rank = 9; + cp->default_rank = 9; + RunSubdivPermsTest(cp, {{0, 9, 16, 24}, {0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10, 11, 12, 13, 14, 15}, @@ -778,28 +788,29 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) { } TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) { - CollectiveParams cp; + auto* cp = new CollectiveParams(); + core::ScopedUnref unref(cp); int num_tasks = 4; - cp.group.device_type = 
DeviceType("GPU"); - cp.group.num_tasks = num_tasks; - cp.group.group_size = 0; - cp.instance.type = BROADCAST_COLLECTIVE; - cp.instance.impl_details.collective_name = "HierarchicalTreeBroadcast"; + cp->group.device_type = DeviceType("GPU"); + cp->group.num_tasks = num_tasks; + cp->group.group_size = 0; + cp->instance.type = BROADCAST_COLLECTIVE; + cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast"; std::vector<int> dev_per_task = {4, 4, 6, 8}; - for (int ti = 0; ti < cp.group.num_tasks; ti++) { + for (int ti = 0; ti < cp->group.num_tasks; ti++) { string task_name = strings::StrCat("/job:worker/replica:0/task:", ti); for (int di = 0; di < dev_per_task[ti]; di++) { string dev_name = strings::StrCat(task_name, "/device:GPU:", di); - cp.group.task_names.push_back(task_name); - cp.group.device_names.push_back(dev_name); - cp.group.group_size++; + cp->group.task_names.push_back(task_name); + cp->group.device_names.push_back(dev_name); + cp->group.group_size++; } } // source 0 device 0 - cp.source_rank = 0; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, + cp->source_rank = 0; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{0, 4, 8, 14}, {0, 1, 2, 3}, {4, 5, 6, 7}, @@ -808,9 +819,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) { {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0}); // source 2 device 0 - cp.source_rank = 2; - cp.default_rank = 0; - RunSubdivPermsTest(&cp, + cp->source_rank = 2; + cp->default_rank = 0; + RunSubdivPermsTest(cp, {{2, 4, 8, 14}, {0, 1, 2, 3}, {4, 5, 6, 7}, @@ -819,9 +830,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) { {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0}); // source 9 device 5 - cp.source_rank = 9; - cp.default_rank = 5; - RunSubdivPermsTest(&cp, + cp->source_rank = 9; + cp->default_rank = 5; + RunSubdivPermsTest(cp, {{0, 4, 9, 14}, {0, 1, 2, 3}, {4, 5, 6, 7}, diff --git a/tensorflow/core/common_runtime/permuter.cc b/tensorflow/core/common_runtime/permuter.cc index 9aee5e5d5c9..c1dcd20dc06 100644 --- a/tensorflow/core/common_runtime/permuter.cc +++ b/tensorflow/core/common_runtime/permuter.cc @@ -54,7 +54,7 @@ Status Permuter::InitializeCollectiveContext( std::shared_ptr<CollectiveContext> col_ctx) { DCHECK(col_ctx->dev_mgr); col_ctx_ = col_ctx; - col_params_ = &col_ctx->col_params; + col_params_ = col_ctx->col_params; return collective_util::InitializeDeviceAndLocality( col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, &col_ctx->device_locality); diff --git a/tensorflow/core/common_runtime/permuter_test.cc b/tensorflow/core/common_runtime/permuter_test.cc index 10c527ca573..a5f8add6c30 100644 --- a/tensorflow/core/common_runtime/permuter_test.cc +++ b/tensorflow/core/common_runtime/permuter_test.cc @@ -107,12 +107,14 @@ class FailTestRMA : public CollectiveRemoteAccessLocal { class PermuterTest : public ::testing::Test { protected: - PermuterTest() : device_type_(DEVICE_CPU) {} + PermuterTest() + : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {} ~PermuterTest() override { stop_ = true; for (auto i : instances_) delete i; if (col_exec_) col_exec_->Unref(); + if (col_params_) col_params_->Unref(); } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -170,12 +172,13 @@ class PermuterTest : public ::testing::Test { col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get(), work_queue_); - col_params_.name = "test_collective"; - col_params_.instance.data_type = dtype; + col_params_ = new CollectiveParams(); + col_params_->name 
= "test_collective"; + col_params_->instance.data_type = dtype; static const int kInstanceKey = 18; - col_params_.instance.instance_key = kInstanceKey; - col_params_.group.device_type = device_type; - col_params_.instance.type = PERMUTE_COLLECTIVE; + col_params_->instance.instance_key = kInstanceKey; + col_params_->group.device_type = device_type; + col_params_->instance.type = PERMUTE_COLLECTIVE; // Set up all the fake device contexts. for (int wi = 0; wi < num_workers; wi++) { @@ -187,12 +190,12 @@ class PermuterTest : public ::testing::Test { } else { dev_name = strings::StrCat(task_name, "/device:CPU:", di); } - col_params_.group.device_names.push_back(dev_name); - col_params_.instance.devices.push_back(dev_name); + col_params_->group.device_names.push_back(dev_name); + col_params_->instance.devices.push_back(dev_name); int default_rank = wi * num_devices_per_worker + di; permutation_.push_back(default_rank); - col_params_.group.task_names.push_back(task_name); - col_params_.task.is_local.push_back(true); + col_params_->group.task_names.push_back(task_name); + col_params_->task.is_local.push_back(true); } } @@ -210,13 +213,13 @@ class PermuterTest : public ::testing::Test { std::next_permutation(permutation_.begin() + i, permutation_.begin() + i + 2); } - col_params_.instance.permutation = permutation_; + col_params_->instance.permutation = permutation_; for (int wi = 0; wi < num_workers; wi++) { for (int di = 0; di < num_devices_per_worker; di++) { int default_rank = wi * num_devices_per_worker + di; instances_.push_back(new DeviceInstance( - default_rank, col_params_.group.device_names[default_rank], + default_rank, col_params_->group.device_names[default_rank], device_type, this)); } } @@ -320,25 +323,30 @@ class PermuterTest : public ::testing::Test { : parent_(parent), dev_name_(dev_name), device_type_(device_type), - rank_(rank) { + rank_(rank), + col_params_(new CollectiveParams()) { TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_)); - col_params_.name = parent_->col_params_.name; - col_params_.instance.data_type = parent_->col_params_.instance.data_type; - col_params_.instance.instance_key = - parent_->col_params_.instance.instance_key; - col_params_.group.device_type = parent_->col_params_.group.device_type; - col_params_.group.device_names = parent_->col_params_.group.device_names; - col_params_.instance.devices = parent_->col_params_.instance.devices; - col_params_.instance.permutation = - parent->col_params_.instance.permutation; - col_params_.group.task_names = parent_->col_params_.group.task_names; - col_params_.task.is_local = parent_->col_params_.task.is_local; - CHECK_EQ(col_params_.instance.devices.size(), - col_params_.group.device_names.size()); + col_params_->name = parent_->col_params_->name; + col_params_->instance.data_type = + parent_->col_params_->instance.data_type; + col_params_->instance.instance_key = + parent_->col_params_->instance.instance_key; + col_params_->group.device_type = parent_->col_params_->group.device_type; + col_params_->group.device_names = + parent_->col_params_->group.device_names; + col_params_->instance.devices = parent_->col_params_->instance.devices; + col_params_->instance.permutation = + parent->col_params_->instance.permutation; + col_params_->group.task_names = parent_->col_params_->group.task_names; + col_params_->task.is_local = parent_->col_params_->task.is_local; + CHECK_EQ(col_params_->instance.devices.size(), + col_params_->group.device_names.size()); // Default rank is order in device_names. 
- col_params_.default_rank = rank; + col_params_->default_rank = rank; } + ~DeviceInstance() { col_params_->Unref(); } + void InitTensor(DataType dtype, const TensorShape& shape, const InitFunc& f) { tensor_input_ = @@ -387,7 +395,7 @@ class PermuterTest : public ::testing::Test { // Prepare a Permuter instance. string exec_key = - strings::StrCat(col_params_.instance.instance_key, ":0:0"); + strings::StrCat(col_params_->instance.instance_key, ":0:0"); Permuter* permuter = new Permuter; core::ScopedUnref unref(permuter); auto col_ctx = std::make_shared<CollectiveContext>( @@ -412,7 +420,7 @@ class PermuterTest : public ::testing::Test { Tensor tensor_input_; Tensor tensor_output_; Device* device_; - CollectiveParams col_params_; + CollectiveParams* col_params_; Status status_; }; // class DeviceInstance @@ -425,7 +433,7 @@ class PermuterTest : public ::testing::Test { std::unique_ptr<DeviceResolverLocal> dev_resolver_; std::shared_ptr<UnboundedWorkQueue> work_queue_; std::vector<DeviceInstance*> instances_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_; std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_; std::unique_ptr<string> gpu_ring_order_; diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc index e664eb90865..a081d2cb730 100644 --- a/tensorflow/core/common_runtime/ring_alg.cc +++ b/tensorflow/core/common_runtime/ring_alg.cc @@ -245,7 +245,7 @@ Status RingAlg::InitializeCollectiveContext( std::shared_ptr<CollectiveContext> col_ctx) { DCHECK(col_ctx->dev_mgr); col_ctx_ = col_ctx; - col_params_ = &col_ctx->col_params; + col_params_ = col_ctx->col_params; return collective_util::InitializeDeviceAndLocality( col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, &col_ctx->device_locality); diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc index 1f23ee1a8a7..0a6f81a5a2a 100644 --- a/tensorflow/core/common_runtime/ring_gatherer_test.cc +++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc @@ -115,7 +115,8 @@ static int64 kStepId = 123; class RingGathererTest : public ::testing::Test { protected: - RingGathererTest() : device_type_(DEVICE_CPU) {} + RingGathererTest() + : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {} #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM void InitGPUDevices() { @@ -132,6 +133,7 @@ class RingGathererTest : public ::testing::Test { stop_ = true; for (auto i : instances_) delete i; if (col_exec_) col_exec_->Unref(); + if (col_params_) col_params_->Unref(); } void Init(int num_workers, int num_devices, DataType dtype, @@ -180,24 +182,25 @@ class RingGathererTest : public ::testing::Test { col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get(), work_queue_); - col_params_.name = "test_collective"; + col_params_ = new CollectiveParams(); + col_params_->name = "test_collective"; static const int kGroupKey = 5; - col_params_.group.group_key = kGroupKey; - col_params_.group.device_type = device_type; - col_params_.group.group_size = num_workers * num_devices; + col_params_->group.group_key = kGroupKey; + col_params_->group.device_type = device_type; + col_params_->group.group_size = num_workers * num_devices; static const int kInstanceKey = 17; - col_params_.instance.instance_key = kInstanceKey; - col_params_.instance.impl_details.subdiv_offsets.clear(); - col_params_.instance.type = GATHER_COLLECTIVE; - 
col_params_.instance.impl_details.collective_name = "RingGather"; - col_params_.instance.data_type = dtype; - col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); - col_params_.subdiv_rank.resize(num_subdivs); + col_params_->instance.instance_key = kInstanceKey; + col_params_->instance.impl_details.subdiv_offsets.clear(); + col_params_->instance.type = GATHER_COLLECTIVE; + col_params_->instance.impl_details.collective_name = "RingGather"; + col_params_->instance.data_type = dtype; + col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs); + col_params_->subdiv_rank.resize(num_subdivs); int subdiv_stride = num_devices / num_subdivs; for (int sdi = 0; sdi < num_subdivs; ++sdi) { - col_params_.instance.impl_details.subdiv_offsets.push_back(sdi * - subdiv_stride); - col_params_.subdiv_rank[sdi] = sdi * subdiv_stride; + col_params_->instance.impl_details.subdiv_offsets.push_back( + sdi * subdiv_stride); + col_params_->subdiv_rank[sdi] = sdi * subdiv_stride; } // Set up a local device ring order that's not just 0,1,2... @@ -225,16 +228,16 @@ class RingGathererTest : public ::testing::Test { dev_name = strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size()); } - col_params_.group.device_names.push_back(dev_name); - col_params_.group.task_names.push_back(task_name); + col_params_->group.device_names.push_back(dev_name); + col_params_->group.task_names.push_back(task_name); // Normally each device would set is_local to its own perspective but // this test runs in a single process so is_local is always true. - col_params_.task.is_local.push_back(true); + col_params_->task.is_local.push_back(true); for (int sdi = 0; sdi < num_subdivs; ++sdi) { int rotated_di = - (di + col_params_.instance.impl_details.subdiv_offsets[sdi]) % + (di + col_params_->instance.impl_details.subdiv_offsets[sdi]) % num_devices; - col_params_.instance.impl_details.subdiv_permutations[sdi].push_back( + col_params_->instance.impl_details.subdiv_permutations[sdi].push_back( wi * num_devices + local_ring_order[rotated_di]); } } @@ -243,7 +246,7 @@ class RingGathererTest : public ::testing::Test { for (int di = 0; di < num_devices; ++di) { int rank = wi * num_devices + di; instances_.push_back(new DeviceInstance( - rank, col_params_.group.device_names[rank], device_type_, this)); + rank, col_params_->group.device_names[rank], device_type_, this)); } } } @@ -387,39 +390,42 @@ class RingGathererTest : public ::testing::Test { : parent_(parent), dev_name_(dev_name), device_type_(device_type), - rank_(rank) { + rank_(rank), + col_params_(new CollectiveParams()) { TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_)) << "Couldn't find device " << dev_name << " existing devices: " << parent_->dev_mgr_->DebugString(); - col_params_.name = parent_->col_params_.name; - col_params_.group = parent_->col_params_.group; - col_params_.instance = parent->col_params_.instance; - col_params_.task.is_local = parent_->col_params_.task.is_local; - col_params_.subdiv_rank = parent_->col_params_.subdiv_rank; + col_params_->name = parent_->col_params_->name; + col_params_->group = parent_->col_params_->group; + col_params_->instance = parent->col_params_->instance; + col_params_->task.is_local = parent_->col_params_->task.is_local; + col_params_->subdiv_rank = parent_->col_params_->subdiv_rank; - int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size()); - int group_size = col_params_.group.group_size; + int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size()); + int group_size 
= col_params_->group.group_size; CHECK_EQ(group_size, - static_cast<int>(col_params_.group.device_names.size())); + static_cast<int>(col_params_->group.device_names.size())); // Id of this device is at rank position in first subdiv perm. int my_device_id = - col_params_.instance.impl_details.subdiv_permutations[0][rank]; - col_params_.default_rank = my_device_id; + col_params_->instance.impl_details.subdiv_permutations[0][rank]; + col_params_->default_rank = my_device_id; // Set rank for all other subdivs by finding that device_id. for (int sdi = 0; sdi < num_subdivs; ++sdi) { - for (int r = 0; r < static_cast<int>(col_params_.instance.impl_details + for (int r = 0; r < static_cast<int>(col_params_->instance.impl_details .subdiv_permutations[sdi] .size()); ++r) { if (my_device_id == - col_params_.instance.impl_details.subdiv_permutations[sdi][r]) { - col_params_.subdiv_rank[sdi] = r; + col_params_->instance.impl_details.subdiv_permutations[sdi][r]) { + col_params_->subdiv_rank[sdi] = r; break; } } } } + ~DeviceInstance() { col_params_->Unref(); } + void InitTensor(DataType dtype, const TensorShape& shape, const std::function<void(Tensor*)>& init_f) { input_tensor_ = @@ -464,7 +470,7 @@ class RingGathererTest : public ::testing::Test { AllocatorAttributes generic_alloc_attr; op_params.output_attr_array = &generic_alloc_attr; std::unique_ptr<OpKernel> op = parent_->GetCollectiveGather( - col_params_, &input_tensor_, DEVICE_CPU, device_); + *col_params_, &input_tensor_, DEVICE_CPU, device_); op_params.op_kernel = op.get(); OpKernelContext ctx(&op_params, 1); @@ -478,7 +484,7 @@ class RingGathererTest : public ::testing::Test { CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0)); // Prepare a RingGatherer instance. string exec_key = - strings::StrCat(col_params_.instance.instance_key, ":0:0"); + strings::StrCat(col_params_->instance.instance_key, ":0:0"); RingGatherer* gatherer = new RingGatherer; core::ScopedUnref unref(gatherer); auto col_ctx = std::make_shared<CollectiveContext>( @@ -507,7 +513,7 @@ class RingGathererTest : public ::testing::Test { Tensor input_tensor_; Tensor output_tensor_; Device* device_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::unique_ptr<CollectiveAdapter> ca_; std::unique_ptr<OpKernelContext> ctx_; Status status_; @@ -521,7 +527,7 @@ class RingGathererTest : public ::testing::Test { std::unique_ptr<DeviceResolverLocal> dev_resolver_; std::shared_ptr<UnboundedWorkQueue> work_queue_; std::vector<DeviceInstance*> instances_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_; std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_; std::unique_ptr<string> gpu_ring_order_; @@ -530,28 +536,28 @@ class RingGathererTest : public ::testing::Test { CancellationManager cancellation_manager_; }; -CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, - const int num_tasks) { - CollectiveParams cp; +CollectiveParams* SetUpCollectiveParams(const int num_devs_per_task, + const int num_tasks) { + auto* cp = new CollectiveParams(); const int kNumDevs = num_devs_per_task * num_tasks; - cp.group.group_key = 1; - cp.group.group_size = kNumDevs; - cp.group.device_type = DeviceType("GPU"); - cp.group.num_tasks = num_tasks; - cp.instance.instance_key = 3; - cp.instance.type = GATHER_COLLECTIVE; - cp.instance.data_type = DataType(DT_FLOAT); - cp.instance.shape = TensorShape({kNumDevs * kNumDevs}); - cp.instance.impl_details.collective_name = "RingGather"; - 
cp.instance.impl_details.subdiv_offsets.push_back(0); - cp.is_source = false; + cp->group.group_key = 1; + cp->group.group_size = kNumDevs; + cp->group.device_type = DeviceType("GPU"); + cp->group.num_tasks = num_tasks; + cp->instance.instance_key = 3; + cp->instance.type = GATHER_COLLECTIVE; + cp->instance.data_type = DataType(DT_FLOAT); + cp->instance.shape = TensorShape({kNumDevs * kNumDevs}); + cp->instance.impl_details.collective_name = "RingGather"; + cp->instance.impl_details.subdiv_offsets.push_back(0); + cp->is_source = false; for (int i = 0; i < kNumDevs; ++i) { int task_id = i / num_devs_per_task; int dev_id = i % num_devs_per_task; string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id); string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id); - cp.group.task_names.push_back(task_name); - cp.group.device_names.push_back(device_name); + cp->group.task_names.push_back(task_name); + cp->group.device_names.push_back(device_name); } return cp; } @@ -559,23 +565,24 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, TEST_F(RingGathererTest, InitializeParams) { const int kNumDevsPerTask = 8; const int kNumTasks = 3; - CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + core::ScopedUnref unref(cp); - cp.default_rank = 0; - cp.instance.impl_details.subdiv_offsets = {}; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, + cp->default_rank = 0; + cp->instance.impl_details.subdiv_offsets = {}; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, {0}); - cp.instance.impl_details.subdiv_offsets = {0}; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, + cp->instance.impl_details.subdiv_offsets = {0}; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, {0}); - cp.default_rank = 3; - cp.instance.impl_details.subdiv_offsets = {}; - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, + cp->default_rank = 3; + cp->instance.impl_details.subdiv_offsets = {}; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, {3}); } diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc index 3b153e4ca1d..89f8605ae25 100644 --- a/tensorflow/core/common_runtime/ring_reducer_test.cc +++ b/tensorflow/core/common_runtime/ring_reducer_test.cc @@ -138,7 +138,8 @@ static int64 kStepId = 123; class RingReducerTest : public ::testing::Test { protected: - RingReducerTest() : device_type_(DEVICE_CPU) {} + RingReducerTest() + : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {} #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM void InitGPUDevices() { @@ -155,6 +156,7 @@ class RingReducerTest : public ::testing::Test { stop_ = true; for (auto i : instances_) delete i; if (col_exec_) col_exec_->Unref(); + if (col_params_) col_params_->Unref(); } void Init(int num_workers, int num_devices, DataType dtype, @@ -203,24 +205,25 @@ class RingReducerTest : public ::testing::Test { col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get(), work_queue_); - col_params_.name = 
"test_collective"; + col_params_ = new CollectiveParams(); + col_params_->name = "test_collective"; static const int kGroupKey = 5; - col_params_.group.group_key = kGroupKey; - col_params_.group.device_type = device_type; - col_params_.group.group_size = num_workers * num_devices; + col_params_->group.group_key = kGroupKey; + col_params_->group.device_type = device_type; + col_params_->group.group_size = num_workers * num_devices; static const int kInstanceKey = 17; - col_params_.instance.instance_key = kInstanceKey; - col_params_.instance.impl_details.subdiv_offsets.clear(); - col_params_.instance.type = REDUCTION_COLLECTIVE; - col_params_.instance.impl_details.collective_name = "RingReduce"; - col_params_.instance.data_type = dtype; - col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); - col_params_.subdiv_rank.resize(num_subdivs); + col_params_->instance.instance_key = kInstanceKey; + col_params_->instance.impl_details.subdiv_offsets.clear(); + col_params_->instance.type = REDUCTION_COLLECTIVE; + col_params_->instance.impl_details.collective_name = "RingReduce"; + col_params_->instance.data_type = dtype; + col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs); + col_params_->subdiv_rank.resize(num_subdivs); int subdiv_stride = num_devices / num_subdivs; for (int sdi = 0; sdi < num_subdivs; ++sdi) { - col_params_.instance.impl_details.subdiv_offsets.push_back(sdi * - subdiv_stride); - col_params_.subdiv_rank[sdi] = sdi * subdiv_stride; + col_params_->instance.impl_details.subdiv_offsets.push_back( + sdi * subdiv_stride); + col_params_->subdiv_rank[sdi] = sdi * subdiv_stride; } // Set up a local device ring order that's not just 0,1,2... @@ -242,23 +245,23 @@ class RingReducerTest : public ::testing::Test { // Set up all of the fake device contexts. for (int wi = 0; wi < num_workers; ++wi) { string task_name = strings::StrCat("/job:worker/replica:0/task:", wi); - col_params_.group.num_devices_per_task[task_name] = num_devices; + col_params_->group.num_devices_per_task[task_name] = num_devices; for (int di = 0; di < num_devices; ++di) { string dev_name = strings::StrCat(task_name, "/cpu:", di); if (device_type == DEVICE_GPU) { dev_name = strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size()); } - col_params_.group.device_names.push_back(dev_name); - col_params_.group.task_names.push_back(task_name); + col_params_->group.device_names.push_back(dev_name); + col_params_->group.task_names.push_back(task_name); // Normally each device would set is_local to its own perspective but // this test runs in a single process so is_local is always true. 
- col_params_.task.is_local.push_back(true); + col_params_->task.is_local.push_back(true); for (int sdi = 0; sdi < num_subdivs; ++sdi) { int rotated_di = - (di + col_params_.instance.impl_details.subdiv_offsets[sdi]) % + (di + col_params_->instance.impl_details.subdiv_offsets[sdi]) % num_devices; - col_params_.instance.impl_details.subdiv_permutations[sdi].push_back( + col_params_->instance.impl_details.subdiv_permutations[sdi].push_back( wi * num_devices + local_ring_order[rotated_di]); } } @@ -267,7 +270,7 @@ class RingReducerTest : public ::testing::Test { for (int di = 0; di < num_devices; ++di) { int rank = wi * num_devices + di; instances_.push_back(new DeviceInstance( - rank, col_params_.group.device_names[rank], device_type_, this)); + rank, col_params_->group.device_names[rank], device_type_, this)); } } } @@ -413,39 +416,42 @@ class RingReducerTest : public ::testing::Test { : parent_(parent), dev_name_(dev_name), device_type_(device_type), - rank_(rank) { + rank_(rank), + col_params_(new CollectiveParams()) { TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_)) << "Couldn't find device " << dev_name << " existing devices: " << parent_->dev_mgr_->DebugString(); - col_params_.name = parent_->col_params_.name; - col_params_.group = parent_->col_params_.group; - col_params_.instance = parent->col_params_.instance; - col_params_.task.is_local = parent_->col_params_.task.is_local; - col_params_.subdiv_rank = parent_->col_params_.subdiv_rank; + col_params_->name = parent_->col_params_->name; + col_params_->group = parent_->col_params_->group; + col_params_->instance = parent->col_params_->instance; + col_params_->task.is_local = parent_->col_params_->task.is_local; + col_params_->subdiv_rank = parent_->col_params_->subdiv_rank; - int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size()); - int group_size = col_params_.group.group_size; + int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size()); + int group_size = col_params_->group.group_size; CHECK_EQ(group_size, - static_cast<int>(col_params_.group.device_names.size())); + static_cast<int>(col_params_->group.device_names.size())); // Id of this device is at rank position in first subdiv perm. int my_device_id = - col_params_.instance.impl_details.subdiv_permutations[0][rank]; - col_params_.default_rank = my_device_id; + col_params_->instance.impl_details.subdiv_permutations[0][rank]; + col_params_->default_rank = my_device_id; // Set rank for all other subdivs by finding that device_id. 
for (int sdi = 0; sdi < num_subdivs; ++sdi) { - for (int r = 0; r < static_cast<int>(col_params_.instance.impl_details + for (int r = 0; r < static_cast<int>(col_params_->instance.impl_details .subdiv_permutations[sdi] .size()); ++r) { if (my_device_id == - col_params_.instance.impl_details.subdiv_permutations[sdi][r]) { - col_params_.subdiv_rank[sdi] = r; + col_params_->instance.impl_details.subdiv_permutations[sdi][r]) { + col_params_->subdiv_rank[sdi] = r; break; } } } } + ~DeviceInstance() { col_params_->Unref(); } + void InitTensor(DataType dtype, const TensorShape& shape, const std::function<void(Tensor*)>& init_f) { tensor_ = @@ -466,10 +472,12 @@ class RingReducerTest : public ::testing::Test { } void DoReduce() { - merge_op_ = GetAdd(col_params_.instance.data_type, device_type_, device_); - final_op_ = GetDiv(col_params_.instance.data_type, device_type_, device_); - col_params_.merge_op = merge_op_.get(); - col_params_.final_op = final_op_.get(); + merge_op_ = + GetAdd(col_params_->instance.data_type, device_type_, device_); + final_op_ = + GetDiv(col_params_->instance.data_type, device_type_, device_); + col_params_->merge_op = merge_op_.get(); + col_params_->final_op = final_op_.get(); // Prepare an OpKernelContext. OpKernelContext::Params op_params; @@ -496,7 +504,7 @@ class RingReducerTest : public ::testing::Test { AllocatorAttributes generic_alloc_attr; op_params.output_attr_array = &generic_alloc_attr; std::unique_ptr<OpKernel> op = parent_->GetCollectiveReduce( - col_params_, &tensor_, DEVICE_CPU, device_); + *col_params_, &tensor_, DEVICE_CPU, device_); op_params.op_kernel = op.get(); OpKernelContext ctx(&op_params, 1); @@ -509,7 +517,7 @@ class RingReducerTest : public ::testing::Test { // Prepare a RingReducer instance. string exec_key = - strings::StrCat(col_params_.instance.instance_key, ":0:0"); + strings::StrCat(col_params_->instance.instance_key, ":0:0"); RingReducer* reducer = new RingReducer; core::ScopedUnref unref(reducer); auto col_ctx = std::make_shared<CollectiveContext>( @@ -535,7 +543,7 @@ class RingReducerTest : public ::testing::Test { int rank_; Tensor tensor_; Device* device_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::unique_ptr<OpKernel> merge_op_; std::unique_ptr<OpKernel> final_op_; std::unique_ptr<CollectiveAdapter> ca_; @@ -551,7 +559,7 @@ class RingReducerTest : public ::testing::Test { std::unique_ptr<DeviceResolverLocal> dev_resolver_; std::shared_ptr<UnboundedWorkQueue> work_queue_; std::vector<DeviceInstance*> instances_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_; std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_; std::unique_ptr<string> gpu_ring_order_; @@ -560,28 +568,28 @@ class RingReducerTest : public ::testing::Test { CancellationManager cancellation_manager_; }; -CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, - const int num_tasks) { - CollectiveParams cp; +CollectiveParams* SetUpCollectiveParams(const int num_devs_per_task, + const int num_tasks) { + auto cp = new CollectiveParams(); const int kNumDevs = num_devs_per_task * num_tasks; - cp.group.group_key = 1; - cp.group.group_size = kNumDevs; - cp.group.device_type = DeviceType("GPU"); - cp.group.num_tasks = num_tasks; - cp.instance.instance_key = 3; - cp.instance.type = REDUCTION_COLLECTIVE; - cp.instance.data_type = DataType(DT_FLOAT); - cp.instance.shape = TensorShape({kNumDevs}); - cp.instance.impl_details.collective_name = "RingReduce"; - 
cp.instance.impl_details.subdiv_offsets.push_back(0); - cp.is_source = false; + cp->group.group_key = 1; + cp->group.group_size = kNumDevs; + cp->group.device_type = DeviceType("GPU"); + cp->group.num_tasks = num_tasks; + cp->instance.instance_key = 3; + cp->instance.type = REDUCTION_COLLECTIVE; + cp->instance.data_type = DataType(DT_FLOAT); + cp->instance.shape = TensorShape({kNumDevs}); + cp->instance.impl_details.collective_name = "RingReduce"; + cp->instance.impl_details.subdiv_offsets.push_back(0); + cp->is_source = false; for (int i = 0; i < kNumDevs; ++i) { int task_id = i / num_devs_per_task; int dev_id = i % num_devs_per_task; string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id); string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id); - cp.group.task_names.push_back(task_name); - cp.group.device_names.push_back(device_name); + cp->group.task_names.push_back(task_name); + cp->group.device_names.push_back(device_name); } return cp; } @@ -589,28 +597,29 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, TEST_F(RingReducerTest, InitializeParams) { const int kNumDevsPerTask = 8; const int kNumTasks = 3; - CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + core::ScopedUnref unref(cp); - cp.default_rank = 0; - cp.instance.impl_details.subdiv_offsets = {0, 4}; - RunSubdivPermsTest(&cp, + cp->default_rank = 0; + cp->instance.impl_details.subdiv_offsets = {0, 4}; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}}, {0, 4}); - cp.instance.impl_details.subdiv_offsets = {0, -4}; - RunSubdivPermsTest(&cp, + cp->instance.impl_details.subdiv_offsets = {0, -4}; + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}}, {0, 3}); - cp.default_rank = 3; - cp.instance.impl_details.subdiv_offsets = {3, -3}; - RunSubdivPermsTest(&cp, + cp->default_rank = 3; + cp->instance.impl_details.subdiv_offsets = {3, -3}; + RunSubdivPermsTest(cp, {{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14, 15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18}, {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9, @@ -622,13 +631,14 @@ TEST_F(RingReducerTest, AutomaticSubdivs) { const int kNumDevsPerTask = 8; const int kNumTasks = 3; const int kNumDevs = kNumDevsPerTask * kNumTasks; - CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + core::ScopedUnref unref(cp); // Test automatic generation of subdiv offsets. - cp.default_rank = 0; - cp.instance.impl_details.subdiv_offsets.clear(); - RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, + cp->default_rank = 0; + cp->instance.impl_details.subdiv_offsets.clear(); + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, {0}); // Set shape so that with 2 subdivs chunk_size is 3 MiB. 
This should cause 2 @@ -638,11 +648,11 @@ TEST_F(RingReducerTest, AutomaticSubdivs) { int num_chunks = kNumDevs * num_subdivs; size_t chunk_size = 3 * 1048576; // 3 MB size_t tensor_size = chunk_size * num_chunks; - cp.instance.shape = + cp->instance.shape = TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))}); } - cp.instance.impl_details.subdiv_offsets.clear(); - RunSubdivPermsTest(&cp, + cp->instance.impl_details.subdiv_offsets.clear(); + RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, @@ -653,12 +663,13 @@ TEST_F(RingReducerTest, AutomaticSubdivs) { TEST_F(RingReducerTest, AutomaticSubdivUpperBound) { const int kNumDevsPerTask = 1; const int kNumTasks = 4; - CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + core::ScopedUnref unref(cp); - cp.default_rank = 0; - cp.instance.impl_details.subdiv_offsets.clear(); - cp.instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)}); - RunSubdivPermsTest(&cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0}); + cp->default_rank = 0; + cp->instance.impl_details.subdiv_offsets.clear(); + cp->instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)}); + RunSubdivPermsTest(cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0}); } // TODO(b/113171733): change to use TEST_P. diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index 1f380dab6f8..c5d846e1b57 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -137,13 +137,14 @@ void CollectiveParamResolverDistributed::CompleteGroupAsync( "running the same version of Tensorflow on all workers.")); return; } - CollectiveParams cp; - cp.group.group_key = request->group_key(); - cp.group.group_size = request->group_size(); - cp.group.device_type = DeviceType(request->device_type()); - cp.instance.type = CollectiveType(request->collective_type()); + auto* cp = new CollectiveParams(); + core::ScopedUnref unref(cp); + cp->group.group_key = request->group_key(); + cp->group.group_size = request->group_size(); + cp->group.device_type = DeviceType(request->device_type()); + cp->instance.type = CollectiveType(request->collective_type()); CompleteGroupDistributed( - request->device_attributes(), &cp, cancel_mgr, + request->device_attributes(), cp, cancel_mgr, [response, done](const Status& s, const GroupRec* gr) { if (s.ok()) { mutex_lock l(gr->mu); @@ -196,7 +197,7 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync( } StatusCallback done_and_cleanup = [cp, done](const Status& s) { done(s); - delete cp; + cp->Unref(); }; CompleteInstanceDistributed( request->device(), gr, cp, cancel_mgr, diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc index 1c62b17fe54..8c9f107b9dc 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc @@ -127,6 +127,13 @@ class FakeCache : public TestWorkerCache { }; class DeviceResDistTest : public ::testing::Test { + public: + ~DeviceResDistTest() override { + for (auto& name_param 
: cp_) { + name_param.second->Unref(); + } + } + protected: void DefineWorkers(int num_workers, int num_devices, const string& device_type, bool nccl) { @@ -181,20 +188,20 @@ class DeviceResDistTest : public ::testing::Test { } } - CollectiveParams CreateCollectiveParams(int num_workers, int num_devices, - const string& device_type) { + CollectiveParams* CreateCollectiveParams(int num_workers, int num_devices, + const string& device_type) { const int kGroupKey = 5; const int kInstanceKey = 3; - CollectiveParams cp; - cp.group.group_key = kGroupKey; - cp.group.group_size = num_workers * num_devices; - cp.group.device_type = DeviceType(device_type); - cp.group.num_tasks = num_workers; - cp.instance.instance_key = kInstanceKey; - cp.instance.type = REDUCTION_COLLECTIVE; - cp.instance.data_type = DT_FLOAT; - cp.instance.shape = TensorShape({64}); - cp.instance.impl_details.subdiv_offsets.push_back(0); + auto* cp = new CollectiveParams(); + cp->group.group_key = kGroupKey; + cp->group.group_size = num_workers * num_devices; + cp->group.device_type = DeviceType(device_type); + cp->group.num_tasks = num_workers; + cp->instance.instance_key = kInstanceKey; + cp->instance.type = REDUCTION_COLLECTIVE; + cp->instance.data_type = DT_FLOAT; + cp->instance.shape = TensorShape({64}); + cp->instance.impl_details.subdiv_offsets.push_back(0); return cp; } @@ -217,7 +224,7 @@ class DeviceResDistTest : public ::testing::Test { int group_size) { Device* device = nullptr; TF_CHECK_OK(device_mgrs_[task_name]->LookupDevice(device_name, &device)); - CollectiveParams* cp = &cp_[device_name]; + CollectiveParams* cp = cp_[device_name]; CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name].get(); CHECK(cp_res); cp_res->CompleteParamsAsync( @@ -252,19 +259,19 @@ class DeviceResDistTest : public ::testing::Test { string device_name = strings::StrCat(task_name, "/device:CPU:", di); int idx = wi * num_devices + di; TF_ASSERT_OK(status_[device_name]); - EXPECT_EQ(cp_[device_name].default_rank, idx); - EXPECT_EQ(cp_[device_name].group.device_names.size(), dev_count); - EXPECT_EQ(cp_[device_name].group.device_names[idx], device_name); - EXPECT_EQ(cp_[device_name].group.task_names[idx], task_name); - ValidateDeviceResolver(cp_[device_name], task_name); + EXPECT_EQ(cp_[device_name]->default_rank, idx); + EXPECT_EQ(cp_[device_name]->group.device_names.size(), dev_count); + EXPECT_EQ(cp_[device_name]->group.device_names[idx], device_name); + EXPECT_EQ(cp_[device_name]->group.task_names[idx], task_name); + ValidateDeviceResolver(*cp_[device_name], task_name); if (idx > 0) { - EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key, - cp_[device_name].group.runtime_details.communicator_key); + EXPECT_EQ(cp_[dev0]->group.runtime_details.communicator_key, + cp_[device_name]->group.runtime_details.communicator_key); for (int i = 0; i < dev_count; ++i) { - EXPECT_EQ(cp_[dev0].group.device_names[i], - cp_[device_name].group.device_names[i]); - EXPECT_EQ(cp_[dev0].group.task_names[i], - cp_[device_name].group.task_names[i]); + EXPECT_EQ(cp_[dev0]->group.device_names[i], + cp_[device_name]->group.device_names[i]); + EXPECT_EQ(cp_[dev0]->group.task_names[i], + cp_[device_name]->group.task_names[i]); } } } @@ -287,6 +294,9 @@ class DeviceResDistTest : public ::testing::Test { for (int i = 0; i < num_devices; ++i) { string device_name = strings::StrCat(worker_name, "/device:", device_type, ":", i); + if (cp_.find(device_name) != cp_.end()) { + cp_[device_name]->Unref(); + } cp_[device_name] = 
CreateCollectiveParams(num_workers, num_devices, device_type); status_.erase(device_name); @@ -305,7 +315,7 @@ class DeviceResDistTest : public ::testing::Test { absl::flat_hash_map<string, std::vector<string>> dev_by_task_; absl::flat_hash_map<string, std::unique_ptr<FakeWorker>> workers_; // Below are keyed by device names; - absl::flat_hash_map<string, CollectiveParams> cp_; + absl::flat_hash_map<string, CollectiveParams*> cp_; absl::flat_hash_map<string, Status> status_; mutex mu_; int num_done_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc index 36e26ba9fe9..b1126471b5c 100644 --- a/tensorflow/core/framework/collective.cc +++ b/tensorflow/core/framework/collective.cc @@ -169,7 +169,7 @@ string CollectiveParams::ToString() const { CollectiveContext::CollectiveContext( CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator, const DeviceMgr* dev_mgr, OpKernelContext* ctx, - OpKernelContext::Params* op_params, const CollectiveParams& col_params, + OpKernelContext::Params* op_params, const CollectiveParams* col_params, const string& exec_key, int64 step_id, const Tensor* input, Tensor* output) : col_exec(col_exec), nccl_communicator(nccl_communicator), @@ -182,7 +182,7 @@ CollectiveContext::CollectiveContext( input(input), output(output), device(nullptr), - device_name(col_params.group.device_names[col_params.default_rank]) {} + device_name(col_params->group.device_names[col_params->default_rank]) {} /*static*/ int64 CollectiveExecutor::kInvalidId = -1; diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index cd4c28e1d2f..1b5b8f7789b 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -132,7 +132,7 @@ struct CollTaskParams { }; // Unique to a single CollectiveOp node. 
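+// Refcounted so that asynchronous users can hold a reference for as long as
+// they may access these params (see e.g. NcclCommunicator::Enqueue).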
-struct CollectiveParams { +struct CollectiveParams : public core::RefCounted { CollGroupParams group; CollInstanceParams instance; CollTaskParams task; @@ -298,7 +298,7 @@ class CollectiveExecutor : public core::RefCounted { virtual void StartAbort(const Status& s) {} virtual void ExecuteAsync(OpKernelContext* ctx, - const CollectiveParams& col_params, + const CollectiveParams* col_params, const string& exec_key, StatusCallback done) { done(errors::Internal( "A collective Op has been called in a context in which " @@ -367,7 +367,7 @@ struct CollectiveContext { const DeviceMgr* dev_mgr; // Not owned OpKernelContext* op_ctx; // Not owned OpKernelContext::Params* op_params; // Not owned - const CollectiveParams& col_params; + const CollectiveParams* col_params; // Not owned const string exec_key; const int64 step_id; const Tensor* input; // Not owned @@ -380,7 +380,7 @@ struct CollectiveContext { NcclCommunicatorInterface* nccl_communicator, const DeviceMgr* dev_mgr, OpKernelContext* ctx, OpKernelContext::Params* op_params, - const CollectiveParams& col_params, const string& exec_key, + const CollectiveParams* col_params, const string& exec_key, int64 step_id, const Tensor* input, Tensor* output); }; diff --git a/tensorflow/core/kernels/collective_nccl.cc b/tensorflow/core/kernels/collective_nccl.cc index 44e0b07e9ad..04c6e8a337b 100644 --- a/tensorflow/core/kernels/collective_nccl.cc +++ b/tensorflow/core/kernels/collective_nccl.cc @@ -61,7 +61,7 @@ Status NcclBase::InitializeCollectiveParams(CollectiveParams* col_params) { Status NcclBase::InitializeCollectiveContext( std::shared_ptr<CollectiveContext> col_ctx) { col_ctx_ = col_ctx; - col_params_ = &col_ctx->col_params; + col_params_ = col_ctx->col_params; return collective_util::InitializeDeviceAndLocality( col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, &col_ctx->device_locality); diff --git a/tensorflow/core/kernels/collective_nccl_reducer.cc b/tensorflow/core/kernels/collective_nccl_reducer.cc index 777c5fc8fc7..6aeec00c1da 100644 --- a/tensorflow/core/kernels/collective_nccl_reducer.cc +++ b/tensorflow/core/kernels/collective_nccl_reducer.cc @@ -88,6 +88,9 @@ void NcclReducer::Run(StatusCallback done) { } else { done_callback = std::move(done); } + // Hold a ref to col_params for the rest of this function. + col_params_->Ref(); + core::ScopedUnref unref(col_params_); col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done_callback)); // If no final_op, then this OpKernel is non-blocking. diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc index f8f18751d86..f9516fd8c13 100644 --- a/tensorflow/core/kernels/collective_ops.cc +++ b/tensorflow/core/kernels/collective_ops.cc @@ -53,7 +53,9 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c, class CollectiveOpV1Kernel : public AsyncOpKernel { public: explicit CollectiveOpV1Kernel(OpKernelConstruction* c) - : AsyncOpKernel(c), name_(name()) {} + : AsyncOpKernel(c), name_(name()), col_params_(new CollectiveParams()) {} + + ~CollectiveOpV1Kernel() override { col_params_->Unref(); } void ComputeAsync(OpKernelContext* c, DoneCallback done) override { CollectiveExecutor* col_exec = c->collective_executor(); @@ -88,28 +90,31 @@ class CollectiveOpV1Kernel : public AsyncOpKernel { // A string encoding instance, frame and iter to be handed off to // the implementation for use in generating RecvBuf keys. 
string GetCollectiveKey(OpKernelContext* c) { - return CollectiveKey(c, col_params_.group.group_key, - col_params_.instance.instance_key); + return CollectiveKey(c, col_params_->group.group_key, + col_params_->instance.instance_key); } // Returns false if calling invocation of ComputeAsync should return // immediately. bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec, const DoneCallback& done) { - if (col_params_.group.group_size > col_params_.group.device_names.size()) { + if (col_params_->group.group_size > + col_params_->group.device_names.size()) { // This is the first invocation: Finish initializing col_params_. // Schedule the `CompleteParamsAsync` call on a work queue that can handle // blocking work because it's not guaranteed that this call cannot block. - c->collective_executor()->RunClosure([this, c, done, col_exec]() { + col_params_->Ref(); + c->collective_executor()->RunClosure([this, c, col_exec, done]() { VLOG(1) << "CollectiveOpKernel CompleteParams for collective " - << col_params_.name << " device " << c->device()->name() - << " group " << col_params_.group.group_key << " instance " - << col_params_.instance.instance_key; + << col_params_->name << " device " << c->device()->name() + << " group " << col_params_->group.group_key << " instance " + << col_params_->instance.instance_key; col_exec->CompleteParamsAsync( - c->device()->attributes(), &col_params_, c->cancellation_manager(), + c->device()->attributes(), col_params_, c->cancellation_manager(), [this, c, done](const Status& s) { + core::ScopedUnref unref(col_params_); if (s.ok()) { - col_params_.instance.impl_details.dependencies = dependencies_; + col_params_->instance.impl_details.dependencies = dependencies_; ComputeAsync(c, done); } else { c->SetStatus(s); @@ -128,7 +133,7 @@ class CollectiveOpV1Kernel : public AsyncOpKernel { DoneCallback done) = 0; string name_; - CollectiveParams col_params_; + CollectiveParams* col_params_; std::vector<int32> dependencies_; }; @@ -136,25 +141,25 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveGatherOpKernel(OpKernelConstruction* c) : CollectiveOpV1Kernel(c) { - col_params_.instance.type = GATHER_COLLECTIVE; - OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); + col_params_->instance.type = GATHER_COLLECTIVE; + OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size)); OP_REQUIRES( - c, col_params_.group.group_size > 0, + c, col_params_->group.group_size > 0, errors::InvalidArgument("group_size must be positive integer but got ", - col_params_.group.group_size)); - OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key)); + col_params_->group.group_size)); + OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key)); OP_REQUIRES_OK( - c, c->GetAttr("instance_key", &col_params_.instance.instance_key)); - OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type)); + c, c->GetAttr("instance_key", &col_params_->instance.instance_key)); + OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type)); OP_REQUIRES_OK( c, c->GetAttr("communication_hint", - &col_params_.instance.impl_details.communication_hint)); + &col_params_->instance.impl_details.communication_hint)); OP_REQUIRES_OK( c, c->GetAttr("timeout_seconds", - &col_params_.instance.impl_details.timeout_seconds)); + &col_params_->instance.impl_details.timeout_seconds)); const NodeDef& real_node = c->def(); - col_params_.name = strings::StrCat(real_node.name(), ": 
Gather"); - col_params_.group.device_type = c->device_type(); + col_params_->name = strings::StrCat(real_node.name(), ": Gather"); + col_params_->group.device_type = c->device_type(); } protected: @@ -162,8 +167,8 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel { DoneCallback done) override { auto output_shape = c->input(0).shape(); output_shape.set_dim( - 0, output_shape.dim_size(0) * col_params_.group.group_size); - col_params_.instance.shape = output_shape; + 0, output_shape.dim_size(0) * col_params_->group.group_size); + col_params_->instance.shape = output_shape; // Allocate output on the first pass through this function. This must be // done immediately, while we're still in the executor thread. Otherwise @@ -173,24 +178,24 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel { // Allocate the output tensor. Tensor* output = nullptr; OP_REQUIRES_OK_ASYNC( - c, c->allocate_output(0, col_params_.instance.shape, &output), done); + c, c->allocate_output(0, col_params_->instance.shape, &output), done); } if (!CanProceedWithCompute(c, col_exec, done)) return; - auto actual_done = [c, group_key = col_params_.group.group_key, - instance_key = col_params_.instance.instance_key, - done](const Status& s) { + auto actual_done = [c, col_params = col_params_, done](const Status& s) { VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() - << " group " << group_key << " instance " << instance_key - << " status " << s; + << " group " << col_params->group.group_key << " instance " + << col_params->instance.instance_key << " status " << s; OP_REQUIRES_OK_ASYNC(c, s, done); done(); + col_params->Unref(); }; VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync start for collective " - << col_params_.name << " device " << c->device()->name() - << " group " << col_params_.group.group_key << " instance " - << col_params_.instance.instance_key; + << col_params_->name << " device " << c->device()->name() + << " group " << col_params_->group.group_key << " instance " + << col_params_->instance.instance_key; + col_params_->Ref(); col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done); } @@ -207,18 +212,18 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveReduceOpKernel(OpKernelConstruction* c) : CollectiveOpV1Kernel(c) { - col_params_.instance.type = REDUCTION_COLLECTIVE; - OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); + col_params_->instance.type = REDUCTION_COLLECTIVE; + OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size)); OP_REQUIRES( - c, col_params_.group.group_size > 0, + c, col_params_->group.group_size > 0, errors::InvalidArgument("group_size must be positive integer but got ", - col_params_.group.group_size)); - OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key)); + col_params_->group.group_size)); + OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key)); OP_REQUIRES_OK( - c, c->GetAttr("instance_key", &col_params_.instance.instance_key)); + c, c->GetAttr("instance_key", &col_params_->instance.instance_key)); OP_REQUIRES_OK( c, c->GetAttr("subdiv_offsets", - &col_params_.instance.impl_details.subdiv_offsets)); + &col_params_->instance.impl_details.subdiv_offsets)); string merge_op_name; OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name)); if (merge_op_name == "Max") { @@ -232,24 +237,26 @@ class CollectiveReduceOpKernel : public 
CollectiveOpV1Kernel { errors::InvalidArgument( "final_op must be one of {\"Id\", \"Div\"} but got ", final_op_name)); - OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type)); + OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type)); OP_REQUIRES_OK(c, c->GetAttr("wait_for", &dependencies_)); OP_REQUIRES_OK( c, c->GetAttr("communication_hint", - &col_params_.instance.impl_details.communication_hint)); + &col_params_->instance.impl_details.communication_hint)); OP_REQUIRES_OK( c, c->GetAttr("timeout_seconds", - &col_params_.instance.impl_details.timeout_seconds)); - VLOG(2) << "CollectiveReduce instance " << col_params_.instance.instance_key - << " merge_op " << merge_op_name << " final_op " << final_op_name + &col_params_->instance.impl_details.timeout_seconds)); + VLOG(2) << "CollectiveReduce instance " + << col_params_->instance.instance_key << " merge_op " + << merge_op_name << " final_op " << final_op_name << " communication_hint " - << col_params_.instance.impl_details.communication_hint - << " timeout " << col_params_.instance.impl_details.timeout_seconds; + << col_params_->instance.impl_details.communication_hint + << " timeout " + << col_params_->instance.impl_details.timeout_seconds; const NodeDef& real_node = c->def(); - col_params_.name = strings::StrCat(real_node.name(), ": Reduce(", - merge_op_name, ",", final_op_name, ")"); - col_params_.group.device_type = c->device_type(); + col_params_->name = strings::StrCat(real_node.name(), ": Reduce(", + merge_op_name, ",", final_op_name, ")"); + col_params_->group.device_type = c->device_type(); // Find the OpKernels by name, type and device type. NodeDef sub_node; @@ -257,12 +264,12 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel { sub_node.add_input(real_node.input(0)); sub_node.add_input(real_node.input(0)); sub_node.set_device(real_node.device()); - SetAttrValue(col_params_.instance.data_type, + SetAttrValue(col_params_->instance.data_type, &(*sub_node.mutable_attr())["T"]); merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node); final_op_ = BuildOpKernel(c, final_op_name, &sub_node); - col_params_.merge_op = merge_op_.get(); - col_params_.final_op = final_op_.get(); + col_params_->merge_op = merge_op_.get(); + col_params_->final_op = final_op_.get(); } protected: @@ -279,24 +286,24 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel { c->forward_input_or_allocate_output( {0}, 0, c->input(0).shape(), &output), done); - col_params_.instance.shape = c->input(0).shape(); + col_params_->instance.shape = c->input(0).shape(); } if (!CanProceedWithCompute(c, col_exec, done)) return; - auto actual_done = [c, group_key = col_params_.group.group_key, - instance_key = col_params_.instance.instance_key, - done](const Status& s) { + auto actual_done = [c, col_params = col_params_, done](const Status& s) { VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() - << " group " << group_key << " instance " << instance_key - << " status " << s; + << " group " << col_params->group.group_key << " instance " + << col_params->instance.instance_key << " status " << s; OP_REQUIRES_OK_ASYNC(c, s, done); done(); + col_params->Unref(); }; VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync start for collective " - << col_params_.name << " device " << c->device()->name() - << " group " << col_params_.group.group_key << " instance " - << col_params_.instance.instance_key; + << col_params_->name << " device " << 
c->device()->name() + << " group " << col_params_->group.group_key << " instance " + << col_params_->instance.instance_key; + col_params_->Ref(); col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done); } @@ -315,29 +322,29 @@ class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c) : CollectiveOpV1Kernel(c) { - col_params_.instance.type = BROADCAST_COLLECTIVE; - OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); + col_params_->instance.type = BROADCAST_COLLECTIVE; + OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size)); OP_REQUIRES( - c, col_params_.group.group_size > 0, + c, col_params_->group.group_size > 0, errors::InvalidArgument("group_size must be positive integer but got ", - col_params_.group.group_size)); - OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key)); + col_params_->group.group_size)); + OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key)); OP_REQUIRES_OK( - c, c->GetAttr("instance_key", &col_params_.instance.instance_key)); - OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type)); - OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape)); + c, c->GetAttr("instance_key", &col_params_->instance.instance_key)); + OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type)); + OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape)); OP_REQUIRES_OK( c, c->GetAttr("communication_hint", - &col_params_.instance.impl_details.communication_hint)); + &col_params_->instance.impl_details.communication_hint)); OP_REQUIRES_OK( c, c->GetAttr("timeout_seconds", - &col_params_.instance.impl_details.timeout_seconds)); - col_params_.is_source = true; - col_params_.instance.impl_details.subdiv_offsets = {0}; + &col_params_->instance.impl_details.timeout_seconds)); + col_params_->is_source = true; + col_params_->instance.impl_details.subdiv_offsets = {0}; - col_params_.name = - strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")"); - col_params_.group.device_type = c->device_type(); + col_params_->name = + strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")"); + col_params_->group.device_type = c->device_type(); } protected: @@ -352,30 +359,30 @@ class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel { Tensor* output = nullptr; OP_REQUIRES_OK_ASYNC(c, c->forward_input_or_allocate_output( - {0}, 0, col_params_.instance.shape, &output), + {0}, 0, col_params_->instance.shape, &output), done); } if (!CanProceedWithCompute(c, col_exec, done)) return; OP_REQUIRES_ASYNC( - c, col_params_.instance.shape.IsSameSize(c->input(0).shape()), - errors::Internal("Declared shape of op ", col_params_.name, + c, col_params_->instance.shape.IsSameSize(c->input(0).shape()), + errors::Internal("Declared shape of op ", col_params_->name, " does not match shape of input"), done); - auto actual_done = [c, group_key = col_params_.group.group_key, - instance_key = col_params_.instance.instance_key, - done](const Status& s) { + auto actual_done = [c, col_params = col_params_, done](const Status& s) { VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() - << " group " << group_key << " instance " << instance_key - << " status " << s; + << " group " << col_params->group.group_key << " instance " + << col_params->instance.instance_key << " status " << s; 
OP_REQUIRES_OK_ASYNC(c, s, done); done(); + col_params->Unref(); }; VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync start for collective " - << col_params_.name << " device " << c->device()->name() - << " group " << col_params_.group.group_key << " instance " - << col_params_.instance.instance_key; + << col_params_->name << " device " << c->device()->name() + << " group " << col_params_->group.group_key << " instance " + << col_params_->instance.instance_key; + col_params_->Ref(); col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done); } @@ -392,29 +399,29 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel { public: explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c) : CollectiveOpV1Kernel(c) { - col_params_.instance.type = BROADCAST_COLLECTIVE; - OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); + col_params_->instance.type = BROADCAST_COLLECTIVE; + OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size)); OP_REQUIRES( - c, col_params_.group.group_size > 0, + c, col_params_->group.group_size > 0, errors::InvalidArgument("group_size must be positive integer but got ", - col_params_.group.group_size)); - OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key)); + col_params_->group.group_size)); + OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key)); OP_REQUIRES_OK( - c, c->GetAttr("instance_key", &col_params_.instance.instance_key)); - OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type)); - OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape)); + c, c->GetAttr("instance_key", &col_params_->instance.instance_key)); + OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type)); + OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape)); OP_REQUIRES_OK( c, c->GetAttr("communication_hint", - &col_params_.instance.impl_details.communication_hint)); + &col_params_->instance.impl_details.communication_hint)); OP_REQUIRES_OK( c, c->GetAttr("timeout_seconds", - &col_params_.instance.impl_details.timeout_seconds)); - col_params_.is_source = false; - col_params_.instance.impl_details.subdiv_offsets = {0}; + &col_params_->instance.impl_details.timeout_seconds)); + col_params_->is_source = false; + col_params_->instance.impl_details.subdiv_offsets = {0}; - col_params_.name = - strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")"); - col_params_.group.device_type = c->device_type(); + col_params_->name = + strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")"); + col_params_->group.device_type = c->device_type(); } protected: @@ -428,24 +435,24 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel { // No input, so must allocate output. 
Tensor* output = nullptr; OP_REQUIRES_OK_ASYNC( - c, c->allocate_output(0, col_params_.instance.shape, &output), done); + c, c->allocate_output(0, col_params_->instance.shape, &output), done); } if (!CanProceedWithCompute(c, col_exec, done)) return; - auto actual_done = [c, group_key = col_params_.group.group_key, - instance_key = col_params_.instance.instance_key, - done](const Status& s) { + auto actual_done = [c, col_params = col_params_, done](const Status& s) { VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() - << " group " << group_key << " instance_key " << instance_key - << " status " << s; + << " group " << col_params->group.group_key << " instance_key " + << col_params->instance.instance_key << " status " << s; OP_REQUIRES_OK_ASYNC(c, s, done); done(); + col_params->Unref(); }; VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync start for collective " - << col_params_.name << " device " << c->device()->name() - << " group " << col_params_.group.group_key << " instance " - << col_params_.instance.instance_key; + << col_params_->name << " device " << c->device()->name() + << " group " << col_params_->group.group_key << " instance " + << col_params_->instance.instance_key; + col_params_->Ref(); col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done); } @@ -534,8 +541,8 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { << col_params->instance.instance_key; auto done_with_cleanup = [col_params, done = std::move(done)]() { - delete col_params; done(); + col_params->Unref(); }; // Allocate the output tensor, trying to reuse the input. @@ -577,7 +584,7 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->ExecuteAsync( - c, *col_params, + c, col_params, CollectiveKey(c, col_params->group.group_key, col_params->instance.instance_key), actual_done); @@ -673,8 +680,8 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel { col_params->instance.shape = output_shape; auto done_with_cleanup = [col_params, done = std::move(done)]() { - delete col_params; done(); + col_params->Unref(); }; Tensor* output = nullptr; @@ -714,7 +721,7 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel { << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->ExecuteAsync( - c, *col_params, + c, col_params, CollectiveKey(c, col_params->group.group_key, col_params->instance.instance_key), actual_done); @@ -797,8 +804,8 @@ class CollectiveBcastSendV2OpKernel : public AsyncOpKernel { << col_params->instance.instance_key; auto done_with_cleanup = [col_params, done = std::move(done)]() { - delete col_params; done(); + col_params->Unref(); }; // Allocate the output tensor, trying to reuse the input. 
@@ -840,7 +847,7 @@ class CollectiveBcastSendV2OpKernel : public AsyncOpKernel { << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->ExecuteAsync( - c, *col_params, + c, col_params, CollectiveKey(c, col_params->group.group_key, col_params->instance.instance_key), actual_done); @@ -905,8 +912,8 @@ class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel { auto col_params = new CollectiveParams(); auto done_with_cleanup = [col_params, done = std::move(done)]() { - delete col_params; done(); + col_params->Unref(); }; OP_REQUIRES_OK_ASYNC( @@ -969,7 +976,7 @@ class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel { << " group " << col_params->group.group_key << " instance " << col_params->instance.instance_key; col_exec->ExecuteAsync( - c, *col_params, + c, col_params, CollectiveKey(c, col_params->group.group_key, col_params->instance.instance_key), actual_done); diff --git a/tensorflow/core/nccl/collective_communicator.cc b/tensorflow/core/nccl/collective_communicator.cc index 233f1ecefd3..2f0659eb121 100644 --- a/tensorflow/core/nccl/collective_communicator.cc +++ b/tensorflow/core/nccl/collective_communicator.cc @@ -69,17 +69,17 @@ std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator() { void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, StatusCallback done) { - const CollectiveParams& col_params = col_ctx->col_params; - const int num_global_devices = col_params.group.group_size; - const int num_local_devices = col_params.group.num_devices_per_task.at( - col_params.group.task_names[col_params.default_rank]); + const CollectiveParams* col_params = col_ctx->col_params; + const int num_global_devices = col_params->group.group_size; + const int num_local_devices = col_params->group.num_devices_per_task.at( + col_params->group.task_names[col_params->default_rank]); const string nccl_collective_key = NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id); auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream(); auto* gpu_info = col_ctx->op_ctx->device()->tensorflow_gpu_device_info(); auto participant = absl::make_unique<NcclManager::Participant>( compute_stream->parent(), compute_stream, gpu_info, col_ctx->input, - col_ctx->output, col_ctx->col_params.default_rank, + col_ctx->output, col_ctx->col_params->default_rank, /*done_callback=*/nullptr); CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager(); if (cancel_mgr == nullptr) { @@ -105,15 +105,24 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, } NcclManager::Context context( nccl_collective_key, num_local_devices, num_global_devices, - col_params.group.runtime_details.communicator_key, - col_params.source_rank); - VLOG(1) << "NcclCommunicator::Enqueue type " << col_params.instance.type - << " num_tasks " << col_params.group.num_tasks << " current task " - << col_params.group.task_names[col_params.default_rank] + col_params->group.runtime_details.communicator_key, + col_params->source_rank); + VLOG(1) << "NcclCommunicator::Enqueue type " << col_params->instance.type + << " num_tasks " << col_params->group.num_tasks << " current task " + << col_params->group.task_names[col_params->default_rank] << " num local devices " << num_local_devices << " num global devices " << num_global_devices << " device " << col_ctx->device_name << " instance " - << col_params.instance.instance_key; + << col_params->instance.instance_key; + // Hold a ref to col_params for the rest of this function. 
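+  // Without this ref, the participant's done_callback could release the last
+  // reference while this function is still reading col_params, for example
+  // in the WaitForDependencies / UnblockDependencies calls below.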
+ // NOTE: an alternate design can be one in which CollectiveParams is not + // refcounted. In such a design, we would need to ensure that the + // done_callback of each participant is called only after this function is + // done with accessing the params. This would likely require some + // coordination mechanism, and may even require the participant thread to + // block until after UnblockDependencies is called below. + col_params->Ref(); + core::ScopedUnref unref(col_params); // `AddTo*` performs consistency checks for the NCCL call and enqueues the // `Participant` struct locally. When all local participants with this // `nccl_collective_key` have called `AddToAllReduce` and @@ -123,10 +132,11 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, // The `NcclManager` uses a dedicated CUDA stream for NCCL kernels. At this // point, it synchronizes the NCCL stream with the compute stream, and then // enqueues the NCCL kernel on the NCCL stream. - switch (col_params.instance.type) { + switch (col_params->instance.type) { case REDUCTION_COLLECTIVE: { ncclRedOp_t reduction_op; - Status s = ReductionOp(col_params.merge_op->type_string(), &reduction_op); + Status s = + ReductionOp(col_params->merge_op->type_string(), &reduction_op); if (!s.ok()) { participant->done_callback(s); return; @@ -140,7 +150,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, break; } case BROADCAST_COLLECTIVE: { - if (col_params.is_source) { + if (col_params->is_source) { nccl_manager_.AddBroadcastSend(std::move(participant), context); } else { nccl_manager_.AddBroadcastRecv(std::move(participant), context); @@ -149,7 +159,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, } default: { participant->done_callback(errors::Internal("Unexpected CollectiveType ", - col_params.instance.type)); + col_params->instance.type)); return; } } @@ -175,7 +185,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, // ready to go. profiler::TraceMe activity("WaitForDependencies", profiler::TraceMeLevel::kInfo); - col_ctx->col_exec->WaitForDependencies(col_params); + col_ctx->col_exec->WaitForDependencies(*col_params); nccl_manager_.SignalMultiNodeReady(nccl_collective_key); } { @@ -184,7 +194,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx, // implementation of `UnblockDependencies` keeps track of the number of // devices that have launched. profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo); - col_ctx->col_exec->UnblockDependencies(col_params); + col_ctx->col_exec->UnblockDependencies(*col_params); } }
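
For reference, the ownership contract introduced above boils down to: any code
that may still touch a `CollectiveParams` after handing it to asynchronous work
takes a reference first and releases it when done, so the params outlive every
accessor. Below is a minimal, self-contained sketch of that pattern;
`RefCounted` and `ScopedUnref` are simplified stand-ins for
`tensorflow::core::RefCounted` and `tensorflow::core::ScopedUnref`, and
`Params` is a toy type, not the real `CollectiveParams`.

#include <atomic>
#include <cassert>
#include <thread>

// Simplified stand-in for tensorflow::core::RefCounted.
class RefCounted {
 public:
  RefCounted() : count_(1) {}  // a newly constructed object carries one ref
  void Ref() { count_.fetch_add(1, std::memory_order_relaxed); }
  void Unref() {
    // Acquire/release so the deleting thread sees all prior writes.
    if (count_.fetch_sub(1, std::memory_order_acq_rel) == 1) delete this;
  }

 protected:
  virtual ~RefCounted() = default;  // destroyed only via Unref()

 private:
  std::atomic<int> count_;
};

// Simplified stand-in for tensorflow::core::ScopedUnref.
class ScopedUnref {
 public:
  explicit ScopedUnref(RefCounted* obj) : obj_(obj) {}
  ~ScopedUnref() {
    if (obj_ != nullptr) obj_->Unref();
  }
  ScopedUnref(const ScopedUnref&) = delete;
  ScopedUnref& operator=(const ScopedUnref&) = delete;

 private:
  RefCounted* obj_;
};

// Toy params type; the real CollectiveParams has many more fields.
struct Params : public RefCounted {
  int group_size = 0;
};

int main() {
  auto* params = new Params();  // refcount == 1, owned by this scope
  params->group_size = 8;

  params->Ref();  // take a reference before handing params to async work
  std::thread done_callback([params] {
    ScopedUnref unref(params);        // the callback releases its own ref
    assert(params->group_size == 8);  // safe: the ref keeps params alive
  });

  {
    // This scope still owns the original reference, so reading params here
    // cannot race with destruction even if done_callback finishes first.
    ScopedUnref unref(params);
    assert(params->group_size == 8);
  }

  done_callback.join();
  return 0;
}

This mirrors the pattern the patch applies throughout: `NcclReducer::Run` takes
a ref and releases it with `core::ScopedUnref`, and the V1 kernels call
`col_params_->Ref()` before `ExecuteAsync` with the matching `Unref()` in
`actual_done`.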