diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
index 18c76915284..3c2344be1e4 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
@@ -515,6 +515,21 @@ func @multiple_replicated_interleaved(%arg0: !tf_res) {
 // -----
 
 
+// Test cluster that is replicated but has a non TPUReplicatedOutput consumer.
+// CHECK-LABEL: func @replicated_non_replicated_output
+func @replicated_non_replicated_output() {
+  %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
+  %1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+  return
+}
+
+// CHECK: [[REPLICATE:%.+]]:2 = tf_device.replicate
+// CHECK: "tf.opB"([[REPLICATE]]#0)
+
+// -----
+
+
 // Test cluster with missing `num_replicas` attribute.
 func @missing_num_replicas() {
   %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
@@ -567,20 +582,6 @@ func @mismatched_replicated_output() {
 // -----
 
 
-// Test cluster that should be replicated where its outputs do not lead to a
-// TPUReplicatedOutput.
-func @missing_replicated_output() {
-  // expected-error@+1 {{requires output of tf_device.cluster to lead to a 'tf.TPUReplicatedOutput' op}}
-  %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
-  %1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
-  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
-  return
-}
-
-
-// -----
-
-
 // Test unused TPUReplicatedInput that has more than one operand.
 func @leftover_replicated_input(%arg0: tensor<i1>) {
   %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (tensor<i1>, tensor<i1>) -> tensor<i1>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
index 4a2e957f5b0..46bc094e5ed 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
@@ -430,20 +430,24 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
   for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
     Value result = result_and_idx.value();
     int idx = result_and_idx.index();
-    for (auto& use : result.getUses()) {
-      Operation* def = use.getOwner();
-      if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
-        return cluster.emitError()
-               << "requires output of " << cluster.getOperationName()
-               << " to lead to a 'tf.TPUReplicatedOutput' op";
+    auto replicate_outputs = llvm::make_range(
+        std::next(replicate_op.result_begin(), idx * num_replicas),
+        std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
 
-      const int def_NumResults = def->getNumResults();
-      if (def_NumResults != num_replicas)
+    for (auto& use : llvm::make_early_inc_range(result.getUses())) {
+      Operation* def = use.getOwner();
+      if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
+        // If user is not a `tf.TPUReplicatedOutput`, simply forward the first
+        // replica output. Certain Graphs under V1 create `tf.Identity` users of
+        // replicated ops to pin the TPU computation for execution.
+        use.set(*replicate_outputs.begin());
+        continue;
+      }
+
+      const int def_num_results = def->getNumResults();
+      if (def_num_results != num_replicas)
         return def->emitOpError() << "requires " << num_replicas << " results";
 
-      auto replicate_outputs = llvm::make_range(
-          std::next(replicate_op.result_begin(), idx * num_replicas),
-          std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
       def->replaceAllUsesWith(replicate_outputs);
     }
   }