Update TPU cluster formation to forward the first replicated output to non-replicated consumers of such outputs.

Some V1 Graphs may have tf.Identity users of replicated outputs where the tf.Identity itself has no users, is not replicated, and does not use a value from a tf.TPUReplicatedOutput. Such tf.Identity ops may be used to pin a TPU computation for execution. The other replicas, which have no such consumers, will still execute because the replicate-to-islands lowering pins control dependencies later in the pipeline.

PiperOrigin-RevId: 337575303
Change-Id: I332fe97bfde5482348e9f7955e0ce4e4512c5da6
This commit is contained in:
Andy Ly 2020-10-16 14:11:36 -07:00 committed by TensorFlower Gardener
parent f701ab5c69
commit 7ce3915a32
2 changed files with 30 additions and 25 deletions
tensorflow/compiler/mlir/tensorflow

View File

@ -515,6 +515,21 @@ func @multiple_replicated_interleaved(%arg0: !tf_res) {
// -----
// Test cluster that is replicated but has a non TPUReplicatedOutput consumer.
// CHECK-LABEL: func @replicated_non_replicated_output
func @replicated_non_replicated_output() {
// `tf.opA` carries `_tpu_replicate = "replicate"`, the same value as the
// `tf.TPUReplicateMetadata` op below.
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
// `tf.opB` has no `_tpu_replicate` attribute and is not a
// `tf.TPUReplicatedOutput`; it consumes `%0` directly. The FileCheck
// expectations after this function require its operand to be result #0 of
// the generated `tf_device.replicate`.
%1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
// Metadata for the "replicate" cluster, requesting two replicas.
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}
// CHECK: [[REPLICATE:%.+]]:2 = tf_device.replicate
// CHECK: "tf.opB"([[REPLICATE]]#0)
// -----
// Test cluster with missing `num_replicas` attribute.
func @missing_num_replicas() {
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
@ -567,20 +582,6 @@ func @mismatched_replicated_output() {
// -----
// Test cluster that should be replicated where its outputs do not lead to a
// TPUReplicatedOutput.
func @missing_replicated_output() {
// The diagnostic below is anchored with `@+1`, i.e. to the `tf.opA` line
// immediately following it; keep the two lines adjacent.
// expected-error@+1 {{requires output of tf_device.cluster to lead to a 'tf.TPUReplicatedOutput' op}}
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
// `tf.opB` has no `_tpu_replicate` attribute and is not a
// `tf.TPUReplicatedOutput`, yet it consumes `%0` directly.
%1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
// Metadata for the "replicate" cluster, requesting two replicas.
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}
// -----
// Test unused TPUReplicatedInput that has more than one operand.
func @leftover_replicated_input(%arg0: tensor<i1>) {
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (tensor<i1>, tensor<i1>) -> tensor<i1>

View File

@ -430,20 +430,24 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
Value result = result_and_idx.value();
int idx = result_and_idx.index();
for (auto& use : result.getUses()) {
Operation* def = use.getOwner();
if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
return cluster.emitError()
<< "requires output of " << cluster.getOperationName()
<< " to lead to a 'tf.TPUReplicatedOutput' op";
auto replicate_outputs = llvm::make_range(
std::next(replicate_op.result_begin(), idx * num_replicas),
std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
const int def_NumResults = def->getNumResults();
if (def_NumResults != num_replicas)
for (auto& use : llvm::make_early_inc_range(result.getUses())) {
Operation* def = use.getOwner();
if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
// If user is not a `tf.TPUReplicatedOutput`, simply forward the first
// replica output. Certain Graphs under V1 create `tf.Identity` users of
// replicated ops to pin the TPU computation for execution.
use.set(*replicate_outputs.begin());
continue;
}
const int def_num_results = def->getNumResults();
if (def_num_results != num_replicas)
return def->emitOpError() << "requires " << num_replicas << " results";
auto replicate_outputs = llvm::make_range(
std::next(replicate_op.result_begin(), idx * num_replicas),
std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
def->replaceAllUsesWith(replicate_outputs);
}
}