diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 2c8736ad42c..40ecd80e4c9 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -191,7 +191,8 @@ Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
 // This is just a minimal set of checks; some instructions may have additional
 // requirements.  For example, all-to-all requires that all ReplicaGroups have
 // the same number of replicas, but that isn't checked here.
-static Status CheckReplicaGroups(HloInstruction* hlo) {
+static Status CheckReplicaGroups(HloInstruction* hlo,
+                                 bool use_global_device_ids) {
   std::set<int64> replicas_seen;
   for (const ReplicaGroup& g : hlo->replica_groups()) {
     if (g.replica_ids().empty()) {
@@ -214,19 +215,24 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
     }
   }
 
+  // If use_global_device_ids() is set, replica_groups cannot be empty.
   // When the channel_id() or use_global_device_ids() is set, device ids in
   // ReplicaGroup config no longer only mean replica ids. So we skip the check
   // on the replica count.
+  if (use_global_device_ids) {
+    if (hlo->replica_groups().empty()) {
+      return InternalError(
+          "Replica group must be specified when use_global_device_ids is true");
+    }
+    // No need to check replica_count.
+    return Status::OK();
+  }
+
   if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
     if (channel_instr->channel_id()) {
       return Status::OK();
     }
   }
-  if (auto all_reduce = DynCast<HloAllReduceInstruction>(hlo)) {
-    if (all_reduce->use_global_device_ids()) {
-      return Status::OK();
-    }
-  }
 
   int64 replica_count = hlo->GetModule()->config().replica_count();
   if (replica_count != 1 && !replicas_seen.empty() &&
@@ -242,14 +248,10 @@ static Status CheckReplicaGroups(HloInstruction* hlo) {
 
 Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
   auto ag = Cast<HloAllGatherInstruction>(hlo);
-  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag));
+  TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, ag->use_global_device_ids()));
   TF_RET_CHECK(ag->all_gather_dimension() >= 0);
   TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
   TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
-  if (ag->use_global_device_ids() && ag->replica_groups().empty()) {
-    return InternalError(
-        "Replica group must be specified when use_global_device_ids is true");
-  }
 
   int64 shard_count = CeilOfRatio(
       ag->shape().dimensions(ag->all_gather_dimension()),
@@ -274,20 +276,21 @@ Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
                             shard_count));
 }
 
-Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) {
-  TF_RETURN_IF_ERROR(CheckReplicaGroups(crs));
+Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
+  auto ar = Cast<HloAllReduceInstruction>(hlo);
+  TF_RETURN_IF_ERROR(CheckReplicaGroups(ar, ar->use_global_device_ids()));
 
   std::vector<const Shape*> operand_shapes;
-  for (const HloInstruction* operand : crs->operands()) {
+  for (const HloInstruction* operand : hlo->operands()) {
     operand_shapes.push_back(&operand->shape());
   }
-  return CheckShape(crs, ShapeInference::InferAllReduceShape(operand_shapes));
+  return CheckShape(hlo, ShapeInference::InferAllReduceShape(operand_shapes));
 }
 
 Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
-  TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo));
-
   auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
+  TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, /*use_global_device_ids=*/false));
+
   TF_RET_CHECK(all_to_all != nullptr);
   if (all_to_all->split_dimension()) {
     if (hlo->replica_groups().empty()) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 03fca5938ff..980fdd170b4 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -57,7 +57,7 @@ class ShapeVerifier : public DfsHloVisitor {
   Status HandleCholesky(HloInstruction* hlo) override;
   Status HandleTriangularSolve(HloInstruction* hlo) override;
   Status HandleAllGather(HloInstruction* hlo) override;
-  Status HandleAllReduce(HloInstruction* crs) override;
+  Status HandleAllReduce(HloInstruction* hlo) override;
   Status HandleAllToAll(HloInstruction* hlo) override;
   Status HandleCollectivePermute(HloInstruction* hlo) override;
   Status HandleCollectivePermuteStart(HloInstruction* hlo) override;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index c6c09e3dee1..7a95cdba7c3 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -1292,5 +1292,29 @@ TEST_F(HloVerifierTest, ComparisonTypePred) {
               HasSubstr("Expected comparison type UNSIGNED"));
 }
 
+TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
+  const char* const hlo_string = R"(
+  HloModule Module
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY CRS {
+    input = f32[8]{0} parameter(0)
+    ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
+                         use_global_device_ids=true, to_apply=add
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnUnverifiedModule(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Replica group must be specified when "
+                        "use_global_device_ids is true"));
+}
+
 }  // namespace
 }  // namespace xla