diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
index 18c76915284..3c2344be1e4 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir
@@ -515,6 +515,21 @@ func @multiple_replicated_interleaved(%arg0: !tf_res) {
 // -----
 
 
+// Test cluster that is replicated but has a non-TPUReplicatedOutput consumer.
+// CHECK-LABEL: func @replicated_non_replicated_output
+func @replicated_non_replicated_output() {
+  %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
+  %1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+  return
+}
+
+// CHECK: [[REPLICATE:%.+]]:2 = tf_device.replicate
+// CHECK: "tf.opB"([[REPLICATE]]#0)
+
+// -----
+
+
 // Test cluster with missing `num_replicas` attribute.
 func @missing_num_replicas() {
   %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
@@ -567,20 +582,6 @@ func @mismatched_replicated_output() {
 // -----
 
 
-// Test cluster that should be replicated where its outputs do not lead to a
-// TPUReplicatedOutput.
-func @missing_replicated_output() {
-  // expected-error@+1 {{requires output of tf_device.cluster to lead to a 'tf.TPUReplicatedOutput' op}}
-  %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
-  %1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
-  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
-  return
-}
-
-
-// -----
-
-
 // Test unused TPUReplicatedInput that has more than one operand.
 func @leftover_replicated_input(%arg0: tensor<i1>) {
   %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (tensor<i1>, tensor<i1>) -> tensor<i1>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
index 4a2e957f5b0..46bc094e5ed 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
@@ -430,20 +430,24 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
   for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
     Value result = result_and_idx.value();
     int idx = result_and_idx.index();
-    for (auto& use : result.getUses()) {
-      Operation* def = use.getOwner();
-      if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
-        return cluster.emitError()
-               << "requires output of " << cluster.getOperationName()
-               << " to lead to a 'tf.TPUReplicatedOutput' op";
+    auto replicate_outputs = llvm::make_range(
+        std::next(replicate_op.result_begin(), idx * num_replicas),
+        std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
 
-      const int def_NumResults = def->getNumResults();
-      if (def_NumResults != num_replicas)
+    for (auto& use : llvm::make_early_inc_range(result.getUses())) {
+      Operation* def = use.getOwner();
+      if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
+        // If user is not a `tf.TPUReplicatedOutput`, simply forward the first
+        // replica output. Certain graphs under V1 create `tf.Identity` users of
+        // replicated ops to pin the TPU computation for execution.
+        use.set(*replicate_outputs.begin());
+        continue;
+      }
+
+      const int def_num_results = def->getNumResults();
+      if (def_num_results != num_replicas)
         return def->emitOpError() << "requires " << num_replicas << " results";
 
-      auto replicate_outputs = llvm::make_range(
-          std::next(replicate_op.result_begin(), idx * num_replicas),
-          std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
       def->replaceAllUsesWith(replicate_outputs);
     }
   }