diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index 978f6e74aa8..281e4baaa12 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -380,6 +380,106 @@ func @resource_before_cluster() { } +// Test cluster formation with ops with attached regions within a cluster. +// Nested ops that are moved should get their _tpu_replicate and device +// attributes cleared. +// CHECK-LABEL: func @cluster_ops_with_regions +func @cluster_ops_with_regions() { + %0 = "tf.opA"() ({ + %1 = "tf.opB"() {_tpu_replicate = "replicate", device = "device", name = "nameB"} : () -> (tensor) + }) {_tpu_replicate = "replicate", device = "device", name = "nameA"} : () -> tensor + "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: "tf.opA"() ( { +// CHECK-NEXT: "tf.opB" +// CHECK-NOT: _tpu_replicate = "replicate" +// CHECK-NOT: device = "device" +// CHECK-SAME: name = "nameB" +// CHECK: }) +// CHECK-NOT: _tpu_replicate = "replicate" +// CHECK-NOT: device = "device" +// CHECK: name = "nameA" +// CHECK: tf_device.return + +// A nested cluster op using result of another cluster op. In the below, opA and +// opB go in a cluster, and opD stays outside. 
+// CHECK-LABEL: func @cluster_nested_op_using_other_op +func @cluster_nested_op_using_other_op() { + %0 = "tf.opA"() { _tpu_replicate = "foo" } : () -> tensor + "tf.opB"() ({ + "tf.opC"(%0) : (tensor) -> () + }) { _tpu_replicate = "foo" } : () -> () + "tf.opD"(%0) : (tensor) -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: [[CLUSTER:%.*]] = "tf_device.cluster"() ( { +// CHECK: [[OPA:%.*]] = "tf.opA"() : () -> tensor +// CHECK: "tf.opB"() ( { +// CHECK: "tf.opC"([[OPA]]) +// CHECK: tf_device.return [[OPA]] +// CHECK: "tf.opD"([[CLUSTER]]) + +// Preceding user is using resource updated by a nested op. +!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @cluster_nested_op_updating_resource +func @cluster_nested_op_updating_resource() { + %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> !tf_res + + "tf.opA"() ({ + "tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor) -> () + "tf.terminator"() : () -> () + }) { _tpu_replicate = "foo" } : () -> () + "tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor) -> () + "tf.opB"() { _tpu_replicate = "foo" } : () -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: [[CONST:%.*]] = "tf.Const" +// CHECK: [[VAR:%.*]] = "tf.VarHandleOp" +// CHECK: "tf_device.cluster"() ( { +// CHECK: "tf.opA"() ( { +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) +// CHECK: }) +// CHECK: "tf.opB"() +// CHECK: tf_device.return +// CHECK: }) +// CHECK-SAME: _tpu_replicate = "foo" +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) + +// Preceding user is using resource updated by the cluster within a nested op. +// Resource is updated by a cluster op, and opA (not in cluster) is using the +// resource in a nested op. 
We expect opA to be after the cluster. +// CHECK-LABEL: func @cluster_nested_op_using_resource +func @cluster_nested_op_using_resource() { + %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> !tf_res + "tf.AssignAddVariableOp"(%1, %0) { _tpu_replicate = "foo" } : (!tf_res, tensor) -> () + "tf.opA"() ({ + "tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor) -> () + "tf.terminator"() : () -> () + }) : () -> () + "tf.opB"() { _tpu_replicate = "foo" } : () -> () + "tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> () + return +} + +// CHECK: [[CONST:%.*]] = "tf.Const" +// CHECK: [[VAR:%.*]] = "tf.VarHandleOp" +// CHECK: "tf_device.cluster"() ( { +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) +// CHECK: "tf.opB"() +// CHECK: tf_device.return +// CHECK: }) +// CHECK-SAME: _tpu_replicate = "foo" +// CHECK: "tf.opA"() ( { +// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]]) + // ----- @@ -407,18 +507,6 @@ func @bad_num_replicas() { // ----- -// Test that functions without TPUReplicateMetadata op are skipped without -// error -// CHECK-LABEL: func @missing_metadata_op -func @missing_metadata_op() { - // expected-warning@+1 {{TPUReplicateMetadata for associated '_tpu_replicate' attribute 'replicate' is missing}} - %0 = "tf.opA"() {_tpu_replicate = "replicate"} : () -> tensor - return -} - -// ----- - - // Test cluster with TPUReplicatedInput where the number of operands does not // match associated `num_replicas` attribute. 
func @mismatched_replicated_input(%arg0: tensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index 597ecfade84..c3f40154c79 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -71,9 +71,11 @@ constexpr char kBadTPUReplicateAttrMsg[] = using MetadataMap = llvm::SmallDenseMap; +// A set of operations in a cluster. +using ClusterOps = llvm::SmallSetVector; + // Mapping for `_tpu_replicate` attribute to ops of a cluster. -using ClusterMap = llvm::SmallDenseMap, 8>; +using ClusterMap = llvm::SmallDenseMap; struct TPUClusterFormation : public TF::PerFunctionAggregateAnalysisConsumerPass< @@ -91,42 +93,40 @@ struct TPUClusterFormation // attribute to its attributes and removes the ops. If multiple // TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error // will be returned. -LogicalResult CollectMetadata(Operation* op, MetadataMap* metadata_map) { - auto result = - op->walk([&](TF::TPUReplicateMetadataOp metadata_op) -> WalkResult { - MutableDictionaryAttr attrs = metadata_op.getAttrs(); +LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) { + // Just look at top-level operations in the block (not nested ones) + for (Operation& op : llvm::make_early_inc_range(*block)) { + auto metadata_op = dyn_cast(op); + if (!metadata_op) continue; - // Missing or bad `_tpu_replicate` attribute. 
- auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr); - if (!tpu_replicate_attr) - return metadata_op.emitError() << kBadTPUReplicateAttrMsg; + MutableDictionaryAttr attrs = metadata_op.getAttrs(); - auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast(); - if (!tpu_replicate_attr_str || - tpu_replicate_attr_str.getValue().empty()) - return metadata_op.emitError() << kBadTPUReplicateAttrMsg; + // Missing or bad `_tpu_replicate` attribute. + auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr); + if (!tpu_replicate_attr) + return metadata_op.emitError() << kBadTPUReplicateAttrMsg; - // Remove `name` attribute. - attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext())); + auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast(); + if (!tpu_replicate_attr_str || tpu_replicate_attr_str.getValue().empty()) + return metadata_op.emitError() << kBadTPUReplicateAttrMsg; - auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(), - std::move(attrs)); + // Remove `name` attribute. + attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext())); - // There are multiple TPUReplicateMetadata ops with the same - // `_tpu_replicate` attribute. - if (!it.second) { - return metadata_op.emitError() - << "multiple TPUReplicateMetadata ops with the same '" - << kTPUReplicateAttr << "' attribute '" - << tpu_replicate_attr_str.getValue() << "' found"; - } + auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(), + std::move(attrs)); - metadata_op.erase(); - return WalkResult::advance(); - }); - - // Return failure if the walk was interrupted. - return failure(result.wasInterrupted()); + // There are multiple TPUReplicateMetadata ops with the same + // `_tpu_replicate` attribute. 
+ if (!it.second) { + return metadata_op.emitError() + << "multiple TPUReplicateMetadata ops with the same '" + << kTPUReplicateAttr << "' attribute '" + << tpu_replicate_attr_str.getValue() << "' found"; + } + metadata_op.erase(); + } + return success(); } // Collects and clusters ops with the same `_tpu_replicate` attribute. This will @@ -154,12 +154,12 @@ void CollectResourceIdsFromOp( op.walk([&](Operation* inner_op) { for (Value operand : TF::filter_resources(inner_op->getOperands())) { if (resource_alias_analysis.IsUnknownResource(operand)) continue; - auto ids = resource_alias_analysis.GetResourceUniqueIds(operand); + const auto& ids = resource_alias_analysis.GetResourceUniqueIds(operand); observed_resource_ids.insert(ids.begin(), ids.end()); } for (Value result : TF::filter_resources(inner_op->getResults())) { if (resource_alias_analysis.IsUnknownResource(result)) continue; - auto ids = resource_alias_analysis.GetResourceUniqueIds(result); + const auto& ids = resource_alias_analysis.GetResourceUniqueIds(result); observed_resource_ids.insert(ids.begin(), ids.end()); } }); @@ -168,13 +168,12 @@ void CollectResourceIdsFromOp( // Checks if an op should be moved after a cluster. There may be users of a // cluster interleaved among the cluster ops. bool ShouldMoveOpAfterCluster( - Block* block, Operation* op, - const llvm::SmallSetVector& cluster_ops, + Block* block, Operation* op, const ClusterOps& cluster_ops, const llvm::SmallSetVector& preceding_users, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis, const llvm::SmallDenseSet& observed_resource_ids) { - auto result = op->walk([&](Operation* op) { - for (Value operand : op->getOperands()) { + auto result = op->walk([&](Operation* inner_op) { + for (Value operand : inner_op->getOperands()) { Operation* def = operand.getDefiningOp(); // Operands may not have a defining op (BlockArgument) or is from a // different block. 
@@ -188,7 +187,7 @@ bool ShouldMoveOpAfterCluster( } // Check for uses of any resource in or after cluster. - for (Value operand : TF::filter_resources(op->getOperands())) { + for (Value operand : TF::filter_resources(inner_op->getOperands())) { if (resource_alias_analysis.IsUnknownResource(operand)) continue; auto ids = resource_alias_analysis.GetResourceUniqueIds(operand); for (const auto& id : ids) @@ -208,13 +207,14 @@ bool ShouldMoveOpAfterCluster( // TODO(lyandy): Extend this to handle all side effecting ops while handling // transitive data dependencies. llvm::SmallSetVector CollectClusterPrecedingUsers( - Block* block, const llvm::SmallSetVector& cluster_ops, + Block* block, const ClusterOps& cluster_ops, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { llvm::SmallSetVector preceding_users; llvm::SmallDenseSet observed_resource_ids; - for (Operation& op : llvm::make_range(Block::iterator(cluster_ops.front()), - Block::iterator(cluster_ops.back()))) { + auto front = Block::iterator(cluster_ops.front()); + auto back = Block::iterator(cluster_ops.back()); + for (Operation& op : llvm::make_range(front, back)) { if (cluster_ops.contains(&op)) { CollectResourceIdsFromOp(op, resource_alias_analysis, observed_resource_ids); @@ -236,7 +236,7 @@ llvm::SmallSetVector CollectClusterPrecedingUsers( // outside of the cluster (i.e. results of ops in the cluster are only consumed // by other ops in the cluster) are pruned. llvm::SmallVector CollectClusterResults( - Block* block, const llvm::SmallSetVector& cluster_ops) { + Block* block, const ClusterOps& cluster_ops) { llvm::SmallVector results; for (Operation* op : cluster_ops) { @@ -255,61 +255,52 @@ llvm::SmallVector CollectClusterResults( } // Creates a `tf_device.cluster` to wrap cluster ops. 
-tf_device::ClusterOp CreateOpForCluster(Operation* last_cluster_op, - llvm::ArrayRef results) { +tf_device::ClusterOp CreateClusterOp( + Block* block, const ClusterOps& cluster_ops, llvm::ArrayRef results, + llvm::ArrayRef preceding_users) { // `tf_device.cluster` will be placed at where the last op of the cluster is. + Operation* last_cluster_op = cluster_ops.back(); OpBuilder builder(last_cluster_op); llvm::SmallVector result_types; for (Value result : results) result_types.push_back(result.getType()); - auto cluster = builder.create(last_cluster_op->getLoc(), result_types); - cluster.body().push_back(new Block); + Block* body = new Block; + cluster.body().push_back(body); + + // Move cluster ops to the cluster body. Also remove `_tpu_replicate` and + // `device` attribute from ops in the cluster as that information will be + // present in the `tf_device.cluster`. Do this for all ops including nested + // ops. + for (Operation* cluster_op : cluster_ops) { + cluster_op->moveBefore(body, body->end()); + cluster_op->walk([&](Operation* inner_op) { + inner_op->removeAttr(kTPUReplicateAttr); + inner_op->removeAttr(kDeviceAttr); + }); + } // Add terminator. - builder.setInsertionPointToEnd(&cluster.GetBody()); + builder.setInsertionPointToEnd(body); builder.create(last_cluster_op->getLoc(), results); - return cluster; -} - -// Moves cluster ops to associated `tf_device.cluster` body. -void MoveClusterOpsToCluster( - tf_device::ClusterOp cluster, - const llvm::SmallSetVector& cluster_ops) { - MLIRContext* context = cluster.getContext(); - Operation* terminator = cluster.GetBody().getTerminator(); - - for (Operation* cluster_op : cluster_ops) { - // Remove `_tpu_replicate` and `device` attribute from ops in the cluster - // as that information will be present in the `tf_device.cluster`. 
- cluster_op->removeAttr(Identifier::get(kTPUReplicateAttr, context)); - cluster_op->removeAttr(Identifier::get(kDeviceAttr, context)); - cluster_op->moveBefore(terminator); - } -} - -// Replaces uses of cluster ops results outside of cluster with the associated -// `tf_device.cluster` results. -void UpdateClusterResultExternalUses(tf_device::ClusterOp cluster, - llvm::ArrayRef results) { - Block& cluster_block = cluster.GetBody(); + // Replaces uses of cluster ops results outside of cluster with the associated + // `tf_device.cluster` results. for (auto ret_vals : llvm::zip(results, cluster.getResults())) { Value old_ret = std::get<0>(ret_vals); Value new_ret = std::get<1>(ret_vals); - for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) - if (!cluster_block.findAncestorOpInBlock(*use.getOwner())) - use.set(new_ret); + for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) { + Operation* user = use.getOwner(); + if (!body->findAncestorOpInBlock(*user)) use.set(new_ret); + } } -} -// Moves users of cluster that are before the cluster to after the cluster. -void MovePrecedingClusterUsers(tf_device::ClusterOp cluster, - llvm::ArrayRef preceding_users) { + // Move users of cluster that are before the cluster to after the cluster. Operation* op_after_cluster = cluster.getOperation()->getNextNode(); for (Operation* user : preceding_users) user->moveBefore(op_after_cluster); + return cluster; } // Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index` @@ -490,10 +481,29 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { // attribute `num_replicas` is greater than 1. // 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`. 
LogicalResult FormClustersInBlock( - Block* block, const MetadataMap& metadata_map, + Block* block, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { + MetadataMap metadata_map; + LogicalResult result = CollectMetadata(block, &metadata_map); + if (failed(result)) return result; + + // If there is no TPUReplicateMetadata op in this block, process blocks in + // regions attached to the ops in the block. + if (metadata_map.empty()) { + for (Operation& op : *block) { + for (Region& region : op.getRegions()) { + if (!llvm::hasSingleElement(region)) + return op.emitOpError("Expected single block region"); + if (failed( + FormClustersInBlock(&region.front(), resource_alias_analysis))) + return failure(); + } + } + return success(); + } + ClusterMap clusters; - LogicalResult result = CollectAndGroupClusterOps(block, &clusters); + result = CollectAndGroupClusterOps(block, &clusters); if (failed(result)) return result; for (const auto& cluster_metadata_and_ops : clusters) { @@ -518,14 +528,8 @@ LogicalResult FormClustersInBlock( llvm::SmallVector results = CollectClusterResults(block, cluster_ops); - tf_device::ClusterOp cluster = - CreateOpForCluster(cluster_ops.back(), results); - - MoveClusterOpsToCluster(cluster, cluster_ops); - - UpdateClusterResultExternalUses(cluster, results); - - MovePrecedingClusterUsers(cluster, preceding_users.getArrayRef()); + tf_device::ClusterOp cluster = CreateClusterOp( + block, cluster_ops, results, preceding_users.getArrayRef()); auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr); if (!num_replicas || !num_replicas.isa()) @@ -548,13 +552,13 @@ LogicalResult FormClustersInBlock( void TPUClusterFormation::runOnFunction( FuncOp func, const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { - MetadataMap metadata_map; - if (failed(CollectMetadata(func, &metadata_map))) return signalPassFailure(); + if (!llvm::hasSingleElement(func)) { + func.emitOpError("Expecting a single block function"); + return 
signalPassFailure(); + } - for (Block& block : func) - if (failed( - FormClustersInBlock(&block, metadata_map, resource_alias_analysis))) - return signalPassFailure(); + if (failed(FormClustersInBlock(&func.front(), resource_alias_analysis))) + return signalPassFailure(); // Remove TPUReplicatedInput and TPUReplicatedOutput nodes. auto remove_result = func.walk([&](Operation* op) {