Update TPU cluster formation to forward the first replicated output to non-replicated consumers of such outputs.
Some V1 Graphs may have tf.Identity users of replicated outputs where these tf.Identity ops have no users of their own, are not replicated, and do not use a value from a tf.TPUReplicatedOutput. Such tf.Identity ops may be used to pin a TPU computation for execution. The other replicas, which have no consumers, will still execute owing to how the replicate-to-islands lowering pins control dependencies later in the pipeline. PiperOrigin-RevId: 337575303 Change-Id: I332fe97bfde5482348e9f7955e0ce4e4512c5da6
This commit is contained in:
parent
f701ab5c69
commit
7ce3915a32
tensorflow/compiler/mlir/tensorflow
@ -515,6 +515,21 @@ func @multiple_replicated_interleaved(%arg0: !tf_res) {
|
||||
// -----
|
||||
|
||||
|
||||
// Test cluster that is replicated but has a non TPUReplicatedOutput consumer.
|
||||
// CHECK-LABEL: func @replicated_non_replicated_output
|
||||
func @replicated_non_replicated_output() {
|
||||
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
|
||||
%1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: [[REPLICATE:%.+]]:2 = tf_device.replicate
|
||||
// CHECK: "tf.opB"([[REPLICATE]]#0)
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// Test cluster with missing `num_replicas` attribute.
|
||||
func @missing_num_replicas() {
|
||||
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
|
||||
@ -567,20 +582,6 @@ func @mismatched_replicated_output() {
|
||||
// -----
|
||||
|
||||
|
||||
// Test cluster that should be replicated where its outputs do not lead to a
|
||||
// TPUReplicatedOutput.
|
||||
func @missing_replicated_output() {
|
||||
// expected-error@+1 {{requires output of tf_device.cluster to lead to a 'tf.TPUReplicatedOutput' op}}
|
||||
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
|
||||
%1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// Test unused TPUReplicatedInput that has more than one operand.
|
||||
func @leftover_replicated_input(%arg0: tensor<i1>) {
|
||||
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
|
@ -430,20 +430,24 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
|
||||
for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
|
||||
Value result = result_and_idx.value();
|
||||
int idx = result_and_idx.index();
|
||||
for (auto& use : result.getUses()) {
|
||||
Operation* def = use.getOwner();
|
||||
if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
|
||||
return cluster.emitError()
|
||||
<< "requires output of " << cluster.getOperationName()
|
||||
<< " to lead to a 'tf.TPUReplicatedOutput' op";
|
||||
auto replicate_outputs = llvm::make_range(
|
||||
std::next(replicate_op.result_begin(), idx * num_replicas),
|
||||
std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
|
||||
|
||||
const int def_NumResults = def->getNumResults();
|
||||
if (def_NumResults != num_replicas)
|
||||
for (auto& use : llvm::make_early_inc_range(result.getUses())) {
|
||||
Operation* def = use.getOwner();
|
||||
if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
|
||||
// If user is not a `tf.TPUReplicatedOutput`, simply forward the first
|
||||
// replica output. Certain Graphs under V1 create `tf.Identity` users of
|
||||
// replicated ops to pin the TPU computation for execution.
|
||||
use.set(*replicate_outputs.begin());
|
||||
continue;
|
||||
}
|
||||
|
||||
const int def_num_results = def->getNumResults();
|
||||
if (def_num_results != num_replicas)
|
||||
return def->emitOpError() << "requires " << num_replicas << " results";
|
||||
|
||||
auto replicate_outputs = llvm::make_range(
|
||||
std::next(replicate_op.result_begin(), idx * num_replicas),
|
||||
std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
|
||||
def->replaceAllUsesWith(replicate_outputs);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user