[MLIR] Fix TPU cluster formation to work with region based control flow
- After converting to region-based control flow, the TPU replicate metadata op could be found in one of the regions attached to an op within a function and not at the top level. Fix TPU cluster formation to support this. - When looking for the replicate metadata op, scan just the top-level operations in the block. If none is found, look for the op in blocks of regions attached to the top-level ops. - Also combine several small functions that aid in building the cluster op into a single one so that it's easier to follow. PiperOrigin-RevId: 331883677 Change-Id: Ic5858ee15282270fae5dca63b019c6616a4036af
This commit is contained in:
parent
1b9addc305
commit
96717822df
@ -380,6 +380,106 @@ func @resource_before_cluster() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Test cluster formation with ops with attached regions within a cluster.
|
||||||
|
// Nested ops that are moved should get their _tpu_replicate and device
|
||||||
|
// attributes cleared.
|
||||||
|
// CHECK-LABEL: func @cluster_ops_with_regions
|
||||||
|
func @cluster_ops_with_regions() {
|
||||||
|
%0 = "tf.opA"() ({
|
||||||
|
%1 = "tf.opB"() {_tpu_replicate = "replicate", device = "device", name = "nameB"} : () -> (tensor<i32>)
|
||||||
|
}) {_tpu_replicate = "replicate", device = "device", name = "nameA"} : () -> tensor<i1>
|
||||||
|
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: "tf.opA"() ( {
|
||||||
|
// CHECK-NEXT: "tf.opB"
|
||||||
|
// CHECK-NOT: _tpu_replicate = "replicate"
|
||||||
|
// CHECK-NOT: device = "device"
|
||||||
|
// CHECK-SAME: name = "nameB"
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK-NOT: _tpu_replicate = "replicate"
|
||||||
|
// CHECK-NOT: device = "device"
|
||||||
|
// CHECK: name = "nameA"
|
||||||
|
// CHECK: tf_device.return
|
||||||
|
|
||||||
|
// A nested cluster op using result of another cluster op. In the below, opA and
|
||||||
|
// opB go in a cluster, and opD stays outside.
|
||||||
|
// CHECK-LABEL: func @cluster_nested_op_using_other_op
|
||||||
|
func @cluster_nested_op_using_other_op() {
|
||||||
|
%0 = "tf.opA"() { _tpu_replicate = "foo" } : () -> tensor<i32>
|
||||||
|
"tf.opB"() ({
|
||||||
|
"tf.opC"(%0) : (tensor<i32>) -> ()
|
||||||
|
}) { _tpu_replicate = "foo" } : () -> ()
|
||||||
|
"tf.opD"(%0) : (tensor<i32>) -> ()
|
||||||
|
"tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: [[CLUSTER:%.*]] = "tf_device.cluster"() ( {
|
||||||
|
// CHECK: [[OPA:%.*]] = "tf.opA"() : () -> tensor<i32>
|
||||||
|
// CHECK: "tf.opB"() ( {
|
||||||
|
// CHECK: "tf.opC"([[OPA]])
|
||||||
|
// CHECK: tf_device.return [[OPA]]
|
||||||
|
// CHECK: "tf.opD"([[CLUSTER]])
|
||||||
|
|
||||||
|
// Preceding user is using resource updated by a nested op.
|
||||||
|
!tf_res = type tensor<*x!tf.resource<tensor<f32>>>
|
||||||
|
// CHECK-LABEL: func @cluster_nested_op_updating_resource
|
||||||
|
func @cluster_nested_op_updating_resource() {
|
||||||
|
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
%1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> !tf_res
|
||||||
|
|
||||||
|
"tf.opA"() ({
|
||||||
|
"tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor<f32>) -> ()
|
||||||
|
"tf.terminator"() : () -> ()
|
||||||
|
}) { _tpu_replicate = "foo" } : () -> ()
|
||||||
|
"tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor<f32>) -> ()
|
||||||
|
"tf.opB"() { _tpu_replicate = "foo" } : () -> ()
|
||||||
|
"tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: [[CONST:%.*]] = "tf.Const"
|
||||||
|
// CHECK: [[VAR:%.*]] = "tf.VarHandleOp"
|
||||||
|
// CHECK: "tf_device.cluster"() ( {
|
||||||
|
// CHECK: "tf.opA"() ( {
|
||||||
|
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK: "tf.opB"()
|
||||||
|
// CHECK: tf_device.return
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK-SAME: _tpu_replicate = "foo"
|
||||||
|
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
|
||||||
|
|
||||||
|
// Preceding user is using resource updated by the cluster within a nested op.
|
||||||
|
// Resource is updated by a cluster op, and opA (not in cluster) is using the
|
||||||
|
// resource in a nested op. We expect opA to be after the cluster.
|
||||||
|
// CHECK-LABEL: func @cluster_nested_op_using_resource
|
||||||
|
func @cluster_nested_op_using_resource() {
|
||||||
|
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
%1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> !tf_res
|
||||||
|
"tf.AssignAddVariableOp"(%1, %0) { _tpu_replicate = "foo" } : (!tf_res, tensor<f32>) -> ()
|
||||||
|
"tf.opA"() ({
|
||||||
|
"tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor<f32>) -> ()
|
||||||
|
"tf.terminator"() : () -> ()
|
||||||
|
}) : () -> ()
|
||||||
|
"tf.opB"() { _tpu_replicate = "foo" } : () -> ()
|
||||||
|
"tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: [[CONST:%.*]] = "tf.Const"
|
||||||
|
// CHECK: [[VAR:%.*]] = "tf.VarHandleOp"
|
||||||
|
// CHECK: "tf_device.cluster"() ( {
|
||||||
|
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
|
||||||
|
// CHECK: "tf.opB"()
|
||||||
|
// CHECK: tf_device.return
|
||||||
|
// CHECK: })
|
||||||
|
// CHECK-SAME: _tpu_replicate = "foo"
|
||||||
|
// CHECK: "tf.opA"() ( {
|
||||||
|
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
||||||
@ -407,18 +507,6 @@ func @bad_num_replicas() {
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
||||||
// Test that functions without TPUReplicateMetadata op are skipped without
|
|
||||||
// error
|
|
||||||
// CHECK-LABEL: func @missing_metadata_op
|
|
||||||
func @missing_metadata_op() {
|
|
||||||
// expected-warning@+1 {{TPUReplicateMetadata for associated '_tpu_replicate' attribute 'replicate' is missing}}
|
|
||||||
%0 = "tf.opA"() {_tpu_replicate = "replicate"} : () -> tensor<i1>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
|
|
||||||
// Test cluster with TPUReplicatedInput where the number of operands does not
|
// Test cluster with TPUReplicatedInput where the number of operands does not
|
||||||
// match associated `num_replicas` attribute.
|
// match associated `num_replicas` attribute.
|
||||||
func @mismatched_replicated_input(%arg0: tensor<i1>) {
|
func @mismatched_replicated_input(%arg0: tensor<i1>) {
|
||||||
|
@ -71,9 +71,11 @@ constexpr char kBadTPUReplicateAttrMsg[] =
|
|||||||
using MetadataMap =
|
using MetadataMap =
|
||||||
llvm::SmallDenseMap<llvm::StringRef, MutableDictionaryAttr, 8>;
|
llvm::SmallDenseMap<llvm::StringRef, MutableDictionaryAttr, 8>;
|
||||||
|
|
||||||
|
// A set of operations in a cluster.
|
||||||
|
using ClusterOps = llvm::SmallSetVector<Operation*, 8>;
|
||||||
|
|
||||||
// Mapping for `_tpu_replicate` attribute to ops of a cluster.
|
// Mapping for `_tpu_replicate` attribute to ops of a cluster.
|
||||||
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef,
|
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, ClusterOps, 8>;
|
||||||
llvm::SmallSetVector<Operation*, 8>, 8>;
|
|
||||||
|
|
||||||
struct TPUClusterFormation
|
struct TPUClusterFormation
|
||||||
: public TF::PerFunctionAggregateAnalysisConsumerPass<
|
: public TF::PerFunctionAggregateAnalysisConsumerPass<
|
||||||
@ -91,42 +93,40 @@ struct TPUClusterFormation
|
|||||||
// attribute to its attributes and removes the ops. If multiple
|
// attribute to its attributes and removes the ops. If multiple
|
||||||
// TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error
|
// TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error
|
||||||
// will be returned.
|
// will be returned.
|
||||||
LogicalResult CollectMetadata(Operation* op, MetadataMap* metadata_map) {
|
LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) {
|
||||||
auto result =
|
// Just look at top-level operations in the block (not nested ones)
|
||||||
op->walk([&](TF::TPUReplicateMetadataOp metadata_op) -> WalkResult {
|
for (Operation& op : llvm::make_early_inc_range(*block)) {
|
||||||
MutableDictionaryAttr attrs = metadata_op.getAttrs();
|
auto metadata_op = dyn_cast<TF::TPUReplicateMetadataOp>(op);
|
||||||
|
if (!metadata_op) continue;
|
||||||
|
|
||||||
// Missing or bad `_tpu_replicate` attribute.
|
MutableDictionaryAttr attrs = metadata_op.getAttrs();
|
||||||
auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
|
|
||||||
if (!tpu_replicate_attr)
|
|
||||||
return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
|
|
||||||
|
|
||||||
auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
|
// Missing or bad `_tpu_replicate` attribute.
|
||||||
if (!tpu_replicate_attr_str ||
|
auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
|
||||||
tpu_replicate_attr_str.getValue().empty())
|
if (!tpu_replicate_attr)
|
||||||
return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
|
return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
|
||||||
|
|
||||||
// Remove `name` attribute.
|
auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
|
||||||
attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext()));
|
if (!tpu_replicate_attr_str || tpu_replicate_attr_str.getValue().empty())
|
||||||
|
return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
|
||||||
|
|
||||||
auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
|
// Remove `name` attribute.
|
||||||
std::move(attrs));
|
attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext()));
|
||||||
|
|
||||||
// There are multiple TPUReplicateMetadata ops with the same
|
auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
|
||||||
// `_tpu_replicate` attribute.
|
std::move(attrs));
|
||||||
if (!it.second) {
|
|
||||||
return metadata_op.emitError()
|
|
||||||
<< "multiple TPUReplicateMetadata ops with the same '"
|
|
||||||
<< kTPUReplicateAttr << "' attribute '"
|
|
||||||
<< tpu_replicate_attr_str.getValue() << "' found";
|
|
||||||
}
|
|
||||||
|
|
||||||
metadata_op.erase();
|
// There are multiple TPUReplicateMetadata ops with the same
|
||||||
return WalkResult::advance();
|
// `_tpu_replicate` attribute.
|
||||||
});
|
if (!it.second) {
|
||||||
|
return metadata_op.emitError()
|
||||||
// Return failure if the walk was interrupted.
|
<< "multiple TPUReplicateMetadata ops with the same '"
|
||||||
return failure(result.wasInterrupted());
|
<< kTPUReplicateAttr << "' attribute '"
|
||||||
|
<< tpu_replicate_attr_str.getValue() << "' found";
|
||||||
|
}
|
||||||
|
metadata_op.erase();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collects and clusters ops with the same `_tpu_replicate` attribute. This will
|
// Collects and clusters ops with the same `_tpu_replicate` attribute. This will
|
||||||
@ -154,12 +154,12 @@ void CollectResourceIdsFromOp(
|
|||||||
op.walk([&](Operation* inner_op) {
|
op.walk([&](Operation* inner_op) {
|
||||||
for (Value operand : TF::filter_resources(inner_op->getOperands())) {
|
for (Value operand : TF::filter_resources(inner_op->getOperands())) {
|
||||||
if (resource_alias_analysis.IsUnknownResource(operand)) continue;
|
if (resource_alias_analysis.IsUnknownResource(operand)) continue;
|
||||||
auto ids = resource_alias_analysis.GetResourceUniqueIds(operand);
|
const auto& ids = resource_alias_analysis.GetResourceUniqueIds(operand);
|
||||||
observed_resource_ids.insert(ids.begin(), ids.end());
|
observed_resource_ids.insert(ids.begin(), ids.end());
|
||||||
}
|
}
|
||||||
for (Value result : TF::filter_resources(inner_op->getResults())) {
|
for (Value result : TF::filter_resources(inner_op->getResults())) {
|
||||||
if (resource_alias_analysis.IsUnknownResource(result)) continue;
|
if (resource_alias_analysis.IsUnknownResource(result)) continue;
|
||||||
auto ids = resource_alias_analysis.GetResourceUniqueIds(result);
|
const auto& ids = resource_alias_analysis.GetResourceUniqueIds(result);
|
||||||
observed_resource_ids.insert(ids.begin(), ids.end());
|
observed_resource_ids.insert(ids.begin(), ids.end());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -168,13 +168,12 @@ void CollectResourceIdsFromOp(
|
|||||||
// Checks if an op should be moved after a cluster. There may be users of a
|
// Checks if an op should be moved after a cluster. There may be users of a
|
||||||
// cluster interleaved among the cluster ops.
|
// cluster interleaved among the cluster ops.
|
||||||
bool ShouldMoveOpAfterCluster(
|
bool ShouldMoveOpAfterCluster(
|
||||||
Block* block, Operation* op,
|
Block* block, Operation* op, const ClusterOps& cluster_ops,
|
||||||
const llvm::SmallSetVector<Operation*, 8>& cluster_ops,
|
|
||||||
const llvm::SmallSetVector<Operation*, 8>& preceding_users,
|
const llvm::SmallSetVector<Operation*, 8>& preceding_users,
|
||||||
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
|
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
|
||||||
const llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
|
const llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
|
||||||
auto result = op->walk([&](Operation* op) {
|
auto result = op->walk([&](Operation* inner_op) {
|
||||||
for (Value operand : op->getOperands()) {
|
for (Value operand : inner_op->getOperands()) {
|
||||||
Operation* def = operand.getDefiningOp();
|
Operation* def = operand.getDefiningOp();
|
||||||
// Operands may not have a defining op (BlockArgument) or is from a
|
// Operands may not have a defining op (BlockArgument) or is from a
|
||||||
// different block.
|
// different block.
|
||||||
@ -188,7 +187,7 @@ bool ShouldMoveOpAfterCluster(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for uses of any resource in or after cluster.
|
// Check for uses of any resource in or after cluster.
|
||||||
for (Value operand : TF::filter_resources(op->getOperands())) {
|
for (Value operand : TF::filter_resources(inner_op->getOperands())) {
|
||||||
if (resource_alias_analysis.IsUnknownResource(operand)) continue;
|
if (resource_alias_analysis.IsUnknownResource(operand)) continue;
|
||||||
auto ids = resource_alias_analysis.GetResourceUniqueIds(operand);
|
auto ids = resource_alias_analysis.GetResourceUniqueIds(operand);
|
||||||
for (const auto& id : ids)
|
for (const auto& id : ids)
|
||||||
@ -208,13 +207,14 @@ bool ShouldMoveOpAfterCluster(
|
|||||||
// TODO(lyandy): Extend this to handle all side effecting ops while handling
|
// TODO(lyandy): Extend this to handle all side effecting ops while handling
|
||||||
// transitive data dependencies.
|
// transitive data dependencies.
|
||||||
llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
|
llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
|
||||||
Block* block, const llvm::SmallSetVector<Operation*, 8>& cluster_ops,
|
Block* block, const ClusterOps& cluster_ops,
|
||||||
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
|
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
|
||||||
llvm::SmallSetVector<Operation*, 8> preceding_users;
|
llvm::SmallSetVector<Operation*, 8> preceding_users;
|
||||||
llvm::SmallDenseSet<int64_t> observed_resource_ids;
|
llvm::SmallDenseSet<int64_t> observed_resource_ids;
|
||||||
|
|
||||||
for (Operation& op : llvm::make_range(Block::iterator(cluster_ops.front()),
|
auto front = Block::iterator(cluster_ops.front());
|
||||||
Block::iterator(cluster_ops.back()))) {
|
auto back = Block::iterator(cluster_ops.back());
|
||||||
|
for (Operation& op : llvm::make_range(front, back)) {
|
||||||
if (cluster_ops.contains(&op)) {
|
if (cluster_ops.contains(&op)) {
|
||||||
CollectResourceIdsFromOp(op, resource_alias_analysis,
|
CollectResourceIdsFromOp(op, resource_alias_analysis,
|
||||||
observed_resource_ids);
|
observed_resource_ids);
|
||||||
@ -236,7 +236,7 @@ llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
|
|||||||
// outside of the cluster (i.e. results of ops in the cluster are only consumed
|
// outside of the cluster (i.e. results of ops in the cluster are only consumed
|
||||||
// by other ops in the cluster) are pruned.
|
// by other ops in the cluster) are pruned.
|
||||||
llvm::SmallVector<Value, 8> CollectClusterResults(
|
llvm::SmallVector<Value, 8> CollectClusterResults(
|
||||||
Block* block, const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
|
Block* block, const ClusterOps& cluster_ops) {
|
||||||
llvm::SmallVector<Value, 8> results;
|
llvm::SmallVector<Value, 8> results;
|
||||||
|
|
||||||
for (Operation* op : cluster_ops) {
|
for (Operation* op : cluster_ops) {
|
||||||
@ -255,61 +255,52 @@ llvm::SmallVector<Value, 8> CollectClusterResults(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Creates a `tf_device.cluster` to wrap cluster ops.
|
// Creates a `tf_device.cluster` to wrap cluster ops.
|
||||||
tf_device::ClusterOp CreateOpForCluster(Operation* last_cluster_op,
|
tf_device::ClusterOp CreateClusterOp(
|
||||||
llvm::ArrayRef<Value> results) {
|
Block* block, const ClusterOps& cluster_ops, llvm::ArrayRef<Value> results,
|
||||||
|
llvm::ArrayRef<Operation*> preceding_users) {
|
||||||
// `tf_device.cluster` will be placed at where the last op of the cluster is.
|
// `tf_device.cluster` will be placed at where the last op of the cluster is.
|
||||||
|
Operation* last_cluster_op = cluster_ops.back();
|
||||||
OpBuilder builder(last_cluster_op);
|
OpBuilder builder(last_cluster_op);
|
||||||
|
|
||||||
llvm::SmallVector<Type, 8> result_types;
|
llvm::SmallVector<Type, 8> result_types;
|
||||||
for (Value result : results) result_types.push_back(result.getType());
|
for (Value result : results) result_types.push_back(result.getType());
|
||||||
|
|
||||||
auto cluster = builder.create<tf_device::ClusterOp>(last_cluster_op->getLoc(),
|
auto cluster = builder.create<tf_device::ClusterOp>(last_cluster_op->getLoc(),
|
||||||
result_types);
|
result_types);
|
||||||
|
|
||||||
cluster.body().push_back(new Block);
|
Block* body = new Block;
|
||||||
|
cluster.body().push_back(body);
|
||||||
|
|
||||||
|
// Move cluster ops to the cluster body. Also remove `_tpu_replicate` and
|
||||||
|
// `device` attribute from ops in the cluster as that information will be
|
||||||
|
// present in the `tf_device.cluster`. Do this for all ops including nested
|
||||||
|
// ops.
|
||||||
|
for (Operation* cluster_op : cluster_ops) {
|
||||||
|
cluster_op->moveBefore(body, body->end());
|
||||||
|
cluster_op->walk([&](Operation* inner_op) {
|
||||||
|
inner_op->removeAttr(kTPUReplicateAttr);
|
||||||
|
inner_op->removeAttr(kDeviceAttr);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Add terminator.
|
// Add terminator.
|
||||||
builder.setInsertionPointToEnd(&cluster.GetBody());
|
builder.setInsertionPointToEnd(body);
|
||||||
builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);
|
builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);
|
||||||
|
|
||||||
return cluster;
|
// Replaces uses of cluster ops results outside of cluster with the associated
|
||||||
}
|
// `tf_device.cluster` results.
|
||||||
|
|
||||||
// Moves cluster ops to associated `tf_device.cluster` body.
|
|
||||||
void MoveClusterOpsToCluster(
|
|
||||||
tf_device::ClusterOp cluster,
|
|
||||||
const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
|
|
||||||
MLIRContext* context = cluster.getContext();
|
|
||||||
Operation* terminator = cluster.GetBody().getTerminator();
|
|
||||||
|
|
||||||
for (Operation* cluster_op : cluster_ops) {
|
|
||||||
// Remove `_tpu_replicate` and `device` attribute from ops in the cluster
|
|
||||||
// as that information will be present in the `tf_device.cluster`.
|
|
||||||
cluster_op->removeAttr(Identifier::get(kTPUReplicateAttr, context));
|
|
||||||
cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
|
|
||||||
cluster_op->moveBefore(terminator);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replaces uses of cluster ops results outside of cluster with the associated
|
|
||||||
// `tf_device.cluster` results.
|
|
||||||
void UpdateClusterResultExternalUses(tf_device::ClusterOp cluster,
|
|
||||||
llvm::ArrayRef<Value> results) {
|
|
||||||
Block& cluster_block = cluster.GetBody();
|
|
||||||
for (auto ret_vals : llvm::zip(results, cluster.getResults())) {
|
for (auto ret_vals : llvm::zip(results, cluster.getResults())) {
|
||||||
Value old_ret = std::get<0>(ret_vals);
|
Value old_ret = std::get<0>(ret_vals);
|
||||||
Value new_ret = std::get<1>(ret_vals);
|
Value new_ret = std::get<1>(ret_vals);
|
||||||
for (auto& use : llvm::make_early_inc_range(old_ret.getUses()))
|
for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) {
|
||||||
if (!cluster_block.findAncestorOpInBlock(*use.getOwner()))
|
Operation* user = use.getOwner();
|
||||||
use.set(new_ret);
|
if (!body->findAncestorOpInBlock(*user)) use.set(new_ret);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Moves users of cluster that are before the cluster to after the cluster.
|
// Move users of cluster that are before the cluster to after the cluster.
|
||||||
void MovePrecedingClusterUsers(tf_device::ClusterOp cluster,
|
|
||||||
llvm::ArrayRef<Operation*> preceding_users) {
|
|
||||||
Operation* op_after_cluster = cluster.getOperation()->getNextNode();
|
Operation* op_after_cluster = cluster.getOperation()->getNextNode();
|
||||||
for (Operation* user : preceding_users) user->moveBefore(op_after_cluster);
|
for (Operation* user : preceding_users) user->moveBefore(op_after_cluster);
|
||||||
|
return cluster;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
|
// Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
|
||||||
@ -490,10 +481,29 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
|
|||||||
// attribute `num_replicas` is greater than 1.
|
// attribute `num_replicas` is greater than 1.
|
||||||
// 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
|
// 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
|
||||||
LogicalResult FormClustersInBlock(
|
LogicalResult FormClustersInBlock(
|
||||||
Block* block, const MetadataMap& metadata_map,
|
Block* block,
|
||||||
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
|
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
|
||||||
|
MetadataMap metadata_map;
|
||||||
|
LogicalResult result = CollectMetadata(block, &metadata_map);
|
||||||
|
if (failed(result)) return result;
|
||||||
|
|
||||||
|
// If there is no TPUReplicateMetadata op in this block, process blocks in
|
||||||
|
// regions attached to the ops in the block.
|
||||||
|
if (metadata_map.empty()) {
|
||||||
|
for (Operation& op : *block) {
|
||||||
|
for (Region& region : op.getRegions()) {
|
||||||
|
if (!llvm::hasSingleElement(region))
|
||||||
|
return op.emitOpError("Expected single block region");
|
||||||
|
if (failed(
|
||||||
|
FormClustersInBlock(&region.front(), resource_alias_analysis)))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
ClusterMap clusters;
|
ClusterMap clusters;
|
||||||
LogicalResult result = CollectAndGroupClusterOps(block, &clusters);
|
result = CollectAndGroupClusterOps(block, &clusters);
|
||||||
if (failed(result)) return result;
|
if (failed(result)) return result;
|
||||||
|
|
||||||
for (const auto& cluster_metadata_and_ops : clusters) {
|
for (const auto& cluster_metadata_and_ops : clusters) {
|
||||||
@ -518,14 +528,8 @@ LogicalResult FormClustersInBlock(
|
|||||||
llvm::SmallVector<Value, 8> results =
|
llvm::SmallVector<Value, 8> results =
|
||||||
CollectClusterResults(block, cluster_ops);
|
CollectClusterResults(block, cluster_ops);
|
||||||
|
|
||||||
tf_device::ClusterOp cluster =
|
tf_device::ClusterOp cluster = CreateClusterOp(
|
||||||
CreateOpForCluster(cluster_ops.back(), results);
|
block, cluster_ops, results, preceding_users.getArrayRef());
|
||||||
|
|
||||||
MoveClusterOpsToCluster(cluster, cluster_ops);
|
|
||||||
|
|
||||||
UpdateClusterResultExternalUses(cluster, results);
|
|
||||||
|
|
||||||
MovePrecedingClusterUsers(cluster, preceding_users.getArrayRef());
|
|
||||||
|
|
||||||
auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr);
|
auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr);
|
||||||
if (!num_replicas || !num_replicas.isa<mlir::IntegerAttr>())
|
if (!num_replicas || !num_replicas.isa<mlir::IntegerAttr>())
|
||||||
@ -548,13 +552,13 @@ LogicalResult FormClustersInBlock(
|
|||||||
void TPUClusterFormation::runOnFunction(
|
void TPUClusterFormation::runOnFunction(
|
||||||
FuncOp func,
|
FuncOp func,
|
||||||
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
|
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
|
||||||
MetadataMap metadata_map;
|
if (!llvm::hasSingleElement(func)) {
|
||||||
if (failed(CollectMetadata(func, &metadata_map))) return signalPassFailure();
|
func.emitOpError("Expecting a single block function");
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
|
||||||
for (Block& block : func)
|
if (failed(FormClustersInBlock(&func.front(), resource_alias_analysis)))
|
||||||
if (failed(
|
return signalPassFailure();
|
||||||
FormClustersInBlock(&block, metadata_map, resource_alias_analysis)))
|
|
||||||
return signalPassFailure();
|
|
||||||
|
|
||||||
// Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
|
// Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
|
||||||
auto remove_result = func.walk([&](Operation* op) {
|
auto remove_result = func.walk([&](Operation* op) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user