diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 93c81052cd8..482b5f37579 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -5181,14 +5181,26 @@ cc_library(
     hdrs = ["collective_ops_utils.h"],
     deps = [
         ":computation_placer",
+        ":global_device_id",
         ":hlo",
+        ":pattern_matcher",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla/service:pattern_matcher",
         "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
-        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",  # fixdeps: keep
-        "//tensorflow/stream_executor/lib",
+    ],
+)
+
+tf_cc_test(
+    name = "collective_ops_utils_test",
+    srcs = ["collective_ops_utils_test.cc"],
+    deps = [
+        ":collective_ops_utils",
+        ":computation_placer",
+        ":global_device_id",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.cc b/tensorflow/compiler/xla/service/collective_ops_utils.cc
index 331f26ca9c3..c2da4834943 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.cc
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 
+#include "tensorflow/compiler/xla/service/global_device_id.h"
+
 namespace xla {
 
 absl::optional<ReductionKind> MatchReductionComputation(
@@ -48,39 +50,59 @@ absl::optional<ReductionKind> MatchReductionComputation(
   }
 }
 
-StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
-    int64 total_replica_count, const DeviceAssignment& device_assn) {
-  std::vector<int64> participating_replicas;
-
-  // Empty replica_groups() means that all replicas participate in one big
-  // group.
+StatusOr<std::vector<int>> GetParticipatingReplicas(
+    int replica_id, int total_replica_count,
+    absl::Span<const ReplicaGroup> replica_groups) {
+  // Empty replica_groups() means that all replicas participate.
   if (replica_groups.empty()) {
-    participating_replicas.resize(total_replica_count);
-    absl::c_iota(participating_replicas, 0);
-    return participating_replicas;
+    std::vector<int> all_replicas(total_replica_count);
+    absl::c_iota(all_replicas, 0);
+    return all_replicas;
   }
 
-  // Use the DeviceAssignment to figure out our replica-id.
-  TF_ASSIGN_OR_RETURN(int replica_id,
-                      device_assn.ReplicaIdForDevice(device_id));
-
   // Figure out the other replicas that go together with this one.
   absl::optional<ReplicaGroup> replica_group;
   for (const ReplicaGroup& g : replica_groups) {
     if (absl::c_linear_search(g.replica_ids(), replica_id)) {
-      CHECK(!replica_group.has_value())
+      TF_RET_CHECK(!replica_group.has_value())
           << "Replica " << replica_id << " appears twice in replica groups";
       replica_group = g;
     }
   }
-  CHECK(replica_group.has_value())
-      << "Replica " << replica_id << " doesn't appear in replica groups? ";
+  TF_RET_CHECK(replica_group.has_value())
+      << "Replica " << replica_id << " doesn't appear in replica groups";
+  return std::vector<int>(replica_group->replica_ids().begin(),
+                          replica_group->replica_ids().end());
+}
 
-  participating_replicas.insert(participating_replicas.begin(),
-                                replica_group->replica_ids().begin(),
-                                replica_group->replica_ids().end());
-  return participating_replicas;
+StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
+    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
+    int total_replica_count, absl::Span<const ReplicaGroup> replica_groups) {
+  std::vector<GlobalDeviceId> participants;
+  // Fast path for common case, avoiding logical IDs lookup.
+  if (replica_groups.empty() && device_assignment.computation_count() == 1) {
+    participants.reserve(total_replica_count);
+    for (int replica_id = 0; replica_id < total_replica_count; ++replica_id) {
+      participants.emplace_back(
+          device_assignment(replica_id, /*computation_id=*/0));
+    }
+    return participants;
+  }
+
+  std::pair<int, int> logical_ids;
+  TF_ASSIGN_OR_RETURN(logical_ids,
+                      device_assignment.LogicalIdsForDevice(device_id));
+  int replica_id = logical_ids.first;
+  int computation_id = logical_ids.second;
+  TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
+                      GetParticipatingReplicas(replica_id, total_replica_count,
+                                               replica_groups));
+
+  participants.reserve(participating_replicas.size());
+  for (int replica_id : participating_replicas) {
+    participants.emplace_back(device_assignment(replica_id, computation_id));
+  }
+  return participants;
 }
 
 }  // end namespace xla
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h
index c67d0739f5c..d7f63cb2663 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.h
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.h
@@ -21,13 +21,13 @@ limitations under the License.
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
 
@@ -37,11 +37,17 @@ enum class ReductionKind { SUM, PRODUCT, MIN, MAX };
 absl::optional<ReductionKind> MatchReductionComputation(
     const HloComputation* computation);
 
-// Figures out which devices (named by their replica-ids) are participating in
-// the collective subgroup that contains device_id.
-StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
-    int64 total_replica_count, const DeviceAssignment& device_assn);
+// Figures out which replicas are participating in the collective subgroup.
+// An empty `replica_groups` indicates that all replicas are participating.
+StatusOr<std::vector<int>> GetParticipatingReplicas(
+    int replica_id, int total_replica_count,
+    absl::Span<const ReplicaGroup> replica_groups);
+
+// Figures out which devices are participating in the collective subgroup.
+// An empty `replica_groups` indicates that all replicas are participating.
+StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
+    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
+    int total_replica_count, absl::Span<const ReplicaGroup> replica_groups);
 
 // Key that identifies a particular Rendezvous object in our global hashtable.
 // This determines which calls to ExecuteOnStream communicate with each other.
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils_test.cc b/tensorflow/compiler/xla/service/collective_ops_utils_test.cc
new file mode 100644
index 00000000000..c08113fb453
--- /dev/null
+++ b/tensorflow/compiler/xla/service/collective_ops_utils_test.cc
@@ -0,0 +1,109 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingReplicas_NoReplicaGroups) {
+  std::vector<int> actual =
+      GetParticipatingReplicas(
+          /*replica_id=*/0, /*total_replica_count=*/3, /*replica_groups=*/{})
+          .ConsumeValueOrDie();
+  std::vector<int> expected = {0, 1, 2};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingReplicas_ReplicaGroups) {
+  std::vector<ReplicaGroup> replica_groups(3);
+  replica_groups[0].add_replica_ids(0);
+  replica_groups[0].add_replica_ids(4);
+  replica_groups[1].add_replica_ids(1);
+  replica_groups[1].add_replica_ids(5);
+  replica_groups[2].add_replica_ids(2);
+  replica_groups[2].add_replica_ids(3);
+
+  std::vector<int> actual =
+      GetParticipatingReplicas(
+          /*replica_id=*/1, /*total_replica_count=*/6, replica_groups)
+          .ConsumeValueOrDie();
+  std::vector<int> expected = {1, 5};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_NoReplicaGroups) {
+  DeviceAssignment device_assignment(/*replica_count=*/3,
+                                     /*computation_count=*/1);
+  device_assignment(0, 0) = 42;
+  device_assignment(1, 0) = 43;
+  device_assignment(2, 0) = 44;
+
+  std::vector<GlobalDeviceId> actual =
+      GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
+                              /*total_replica_count=*/3, /*replica_groups=*/{})
+          .ConsumeValueOrDie();
+  std::vector<GlobalDeviceId> expected = {
+      GlobalDeviceId(42), GlobalDeviceId(43), GlobalDeviceId(44)};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_ReplicaGroups) {
+  DeviceAssignment device_assignment(/*replica_count=*/4,
+                                     /*computation_count=*/1);
+  device_assignment(0, 0) = 42;
+  device_assignment(1, 0) = 43;
+  device_assignment(2, 0) = 44;
+  device_assignment(3, 0) = 45;
+
+  std::vector<ReplicaGroup> replica_groups(2);
+  replica_groups[0].add_replica_ids(0);
+  replica_groups[0].add_replica_ids(3);
+  replica_groups[1].add_replica_ids(1);
+  replica_groups[1].add_replica_ids(2);
+
+  std::vector<GlobalDeviceId> actual =
+      GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
+                              /*total_replica_count=*/4, replica_groups)
+          .ConsumeValueOrDie();
+  std::vector<GlobalDeviceId> expected = {GlobalDeviceId(42),
+                                          GlobalDeviceId(45)};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_MultipleComputations) {
+  DeviceAssignment device_assignment(/*replica_count=*/2,
+                                     /*computation_count=*/2);
+  device_assignment(0, 0) = 42;
+  device_assignment(1, 0) = 43;
+  device_assignment(0, 1) = 44;
+  device_assignment(1, 1) = 45;
+
+  std::vector<GlobalDeviceId> actual =
+      GetParticipatingDevices(GlobalDeviceId(44), device_assignment,
+                              /*total_replica_count=*/2, /*replica_groups=*/{})
+          .ConsumeValueOrDie();
+  std::vector<GlobalDeviceId> expected = {GlobalDeviceId(44),
+                                          GlobalDeviceId(45)};
+  EXPECT_EQ(actual, expected);
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index a8cc531e5c6..d4d78f5ac12 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -167,7 +167,7 @@ struct AllToAllParticipantData : xla::ParticipantData {
 
   // Replica ids participating in AllToAll, concatenation happens in the order
   // of appearence.
-  std::vector<xla::int64> replica_ids_to_copy_to;
+  std::vector<int> replica_ids_to_copy_to;
 
   std::string ToString() const override {
     auto addr_formatter = [](std::string* out,
@@ -348,7 +348,7 @@ class CpuAllToAllRendezvous
       replica_id_map[p.replica_id] = pos;
     }
 
-    const std::vector<xla::int64>& replica_ids_to_copy_to =
+    const std::vector<int>& replica_ids_to_copy_to =
         participants_[0].replica_ids_to_copy_to;
 
     // Replica id -> rank
@@ -598,26 +598,19 @@ xla::RendezvousKey GetRendezvousKey(
     xla::int64 op_id) {
   const xla::DeviceAssignment& device_assignment =
       *run_options->device_assignment();
-  xla::int32 replica_count = device_assignment.replica_count();
   int device_ordinal = GetDeviceOrdinal(run_options);
-  CHECK_EQ(device_assignment.computation_count(), 1);
-  std::vector<xla::int64> participating_replicas =
-      xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
-                                    replica_count,
-                                    *run_options->device_assignment())
-          .ValueOrDie();
   xla::RendezvousKey::CollectiveOpKind op_kind =
       channel_id_present ? xla::RendezvousKey::kCrossModule
                          : xla::RendezvousKey::kCrossReplica;
-  std::vector<xla::GlobalDeviceId> participating_devices;
-  participating_devices.reserve(participating_replicas.size());
-  for (xla::int64 replica : participating_replicas) {
-    participating_devices.push_back(
-        xla::GlobalDeviceId(device_assignment(replica, 0)));
-  }
-  return xla::RendezvousKey{
-      run_options->run_id(), std::move(participating_devices),
-      static_cast<int>(participating_replicas.size()), op_kind, op_id};
+  std::vector<xla::GlobalDeviceId> participating_devices =
+      xla::GetParticipatingDevices(xla::GlobalDeviceId(device_ordinal),
+                                   device_assignment,
+                                   device_assignment.replica_count(), group)
+          .ValueOrDie();
+  int num_local_participants = participating_devices.size();
+  return xla::RendezvousKey{run_options->run_id(),
+                            std::move(participating_devices),
+                            num_local_participants, op_kind, op_id};
 }
 
 }  // namespace
@@ -644,9 +637,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
   participant.replica_id = replica_id;
   participant.replica_ids_to_copy_to =
       xla::GetParticipatingReplicas(
-          xla::GlobalDeviceId(device_ordinal), group,
-          run_options->device_assignment()->replica_count(),
-          *run_options->device_assignment())
+          replica_id, run_options->device_assignment()->replica_count(), group)
           .ValueOrDie();
   for (int i = 0; i < num_buffers; i++) {
     participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index 71be037856e..9b797f7e9e6 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -24,6 +24,8 @@ limitations under the License.
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_format.h"
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
 #if GOOGLE_CUDA
 #include "third_party/nccl/nccl.h"
 #elif TENSORFLOW_USE_ROCM
@@ -132,28 +134,18 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   int device_ordinal = executor->device_ordinal();
   TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                       params.GetGlobalDeviceId());
-  // Determines the set of global and local devices that are participating in
-  // the same collective group as the caller.
+
   TF_ASSIGN_OR_RETURN(
-      std::vector<int64> participating_replicas,
-      GetParticipatingReplicas(global_device_id, config_.replica_groups,
-                               config_.replica_count, *params.device_assn));
-  if (IsGlobalNcclConfig() &&
-      participating_replicas.size() != config_.replica_count) {
+      std::vector<GlobalDeviceId> participants,
+      GetParticipatingDevices(global_device_id, *params.device_assn,
+                              config_.replica_count, config_.replica_groups));
+
+  if (IsGlobalNcclConfig() && (participants.size() != config_.replica_count)) {
     return InvalidArgument(
        "Partial replica groups are not allowed when using NCCL_COMM_ID "
        "environment configuration.");
   }
 
-  TF_RET_CHECK(params.device_assn->computation_count() == 1)
-      << params.device_assn->ToString();
-  std::vector<GlobalDeviceId> participants;
-  participants.reserve(participating_replicas.size());
-  for (int64 replica : participating_replicas) {
-    participants.emplace_back(
-        (*params.device_assn)(replica, /*computation=*/0));
-  }
-
   TF_ASSIGN_OR_RETURN(
       std::vector<LocalParticipant> local_participants,
       GetLocalParticipants(participants, params.gpu_global_device_ids));