[MLIR] Fix TPU cluster formation to work with region based control flow
- After converting to region-based control flow, the TPU replicate metadata op could be found in one of the regions attached to an op within a function and not at the top level. Fix TPU cluster formation to support this. - When looking for the replicate metadata op, scan just the top-level operations in the block. If none is found, look for the op in blocks of regions attached to the top-level ops. - Also combine several small functions that aid in building the cluster op into a single one so that it's easier to follow. PiperOrigin-RevId: 331883677 Change-Id: Ic5858ee15282270fae5dca63b019c6616a4036af
This commit is contained in:
parent
1b9addc305
commit
96717822df
@ -380,6 +380,106 @@ func @resource_before_cluster() {
|
||||
}
|
||||
|
||||
|
||||
// Test cluster formation with ops with attached regions within a cluster.
|
||||
// Nested ops that are moved should get their _tpu_replicate and device
|
||||
// attributes cleared.
|
||||
// CHECK-LABEL: func @cluster_ops_with_regions
|
||||
func @cluster_ops_with_regions() {
|
||||
%0 = "tf.opA"() ({
|
||||
%1 = "tf.opB"() {_tpu_replicate = "replicate", device = "device", name = "nameB"} : () -> (tensor<i32>)
|
||||
}) {_tpu_replicate = "replicate", device = "device", name = "nameA"} : () -> tensor<i1>
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: "tf.opA"() ( {
|
||||
// CHECK-NEXT: "tf.opB"
|
||||
// CHECK-NOT: _tpu_replicate = "replicate"
|
||||
// CHECK-NOT: device = "device"
|
||||
// CHECK-SAME: name = "nameB"
|
||||
// CHECK: })
|
||||
// CHECK-NOT: _tpu_replicate = "replicate"
|
||||
// CHECK-NOT: device = "device"
|
||||
// CHECK: name = "nameA"
|
||||
// CHECK: tf_device.return
|
||||
|
||||
// A nested cluster op using result of another cluster op. In the below, opA and
|
||||
// opB go in a cluster, and opD stays outside.
|
||||
// CHECK-LABEL: func @cluster_nested_op_using_other_op
|
||||
func @cluster_nested_op_using_other_op() {
|
||||
%0 = "tf.opA"() { _tpu_replicate = "foo" } : () -> tensor<i32>
|
||||
"tf.opB"() ({
|
||||
"tf.opC"(%0) : (tensor<i32>) -> ()
|
||||
}) { _tpu_replicate = "foo" } : () -> ()
|
||||
"tf.opD"(%0) : (tensor<i32>) -> ()
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: [[CLUSTER:%.*]] = "tf_device.cluster"() ( {
|
||||
// CHECK: [[OPA:%.*]] = "tf.opA"() : () -> tensor<i32>
|
||||
// CHECK: "tf.opB"() ( {
|
||||
// CHECK: "tf.opC"([[OPA]])
|
||||
// CHECK: tf_device.return [[OPA]]
|
||||
// CHECK: "tf.opD"([[CLUSTER]])
|
||||
|
||||
// Preceding user is using resource updated by a nested op.
|
||||
!tf_res = type tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK-LABEL: func @cluster_nested_op_updating_resource
|
||||
func @cluster_nested_op_updating_resource() {
|
||||
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> !tf_res
|
||||
|
||||
"tf.opA"() ({
|
||||
"tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor<f32>) -> ()
|
||||
"tf.terminator"() : () -> ()
|
||||
}) { _tpu_replicate = "foo" } : () -> ()
|
||||
"tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor<f32>) -> ()
|
||||
"tf.opB"() { _tpu_replicate = "foo" } : () -> ()
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: [[CONST:%.*]] = "tf.Const"
|
||||
// CHECK: [[VAR:%.*]] = "tf.VarHandleOp"
|
||||
// CHECK: "tf_device.cluster"() ( {
|
||||
// CHECK: "tf.opA"() ( {
|
||||
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
|
||||
// CHECK: })
|
||||
// CHECK: "tf.opB"()
|
||||
// CHECK: tf_device.return
|
||||
// CHECK: })
|
||||
// CHECK-SAME: _tpu_replicate = "foo"
|
||||
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
|
||||
|
||||
// Preceding user is using resource updated by the cluster within a nested op.
|
||||
// Resource is updated by a cluster op, and opA (not in cluster) is using the
|
||||
// resource in a nested op. We expect opA to be after the cluster.
|
||||
// CHECK-LABEL: func @cluster_nested_op_using_resource
|
||||
func @cluster_nested_op_using_resource() {
|
||||
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.VarHandleOp"() {container = "", shape = #tf.shape<>, shared_name = "x"} : () -> !tf_res
|
||||
"tf.AssignAddVariableOp"(%1, %0) { _tpu_replicate = "foo" } : (!tf_res, tensor<f32>) -> ()
|
||||
"tf.opA"() ({
|
||||
"tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor<f32>) -> ()
|
||||
"tf.terminator"() : () -> ()
|
||||
}) : () -> ()
|
||||
"tf.opB"() { _tpu_replicate = "foo" } : () -> ()
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: [[CONST:%.*]] = "tf.Const"
|
||||
// CHECK: [[VAR:%.*]] = "tf.VarHandleOp"
|
||||
// CHECK: "tf_device.cluster"() ( {
|
||||
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
|
||||
// CHECK: "tf.opB"()
|
||||
// CHECK: tf_device.return
|
||||
// CHECK: })
|
||||
// CHECK-SAME: _tpu_replicate = "foo"
|
||||
// CHECK: "tf.opA"() ( {
|
||||
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
@ -407,18 +507,6 @@ func @bad_num_replicas() {
|
||||
// -----
|
||||
|
||||
|
||||
// Test that functions without TPUReplicateMetadata op are skipped without
|
||||
// error
|
||||
// CHECK-LABEL: func @missing_metadata_op
|
||||
func @missing_metadata_op() {
|
||||
// expected-warning@+1 {{TPUReplicateMetadata for associated '_tpu_replicate' attribute 'replicate' is missing}}
|
||||
%0 = "tf.opA"() {_tpu_replicate = "replicate"} : () -> tensor<i1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// Test cluster with TPUReplicatedInput where the number of operands does not
|
||||
// match associated `num_replicas` attribute.
|
||||
func @mismatched_replicated_input(%arg0: tensor<i1>) {
|
||||
|
@ -71,9 +71,11 @@ constexpr char kBadTPUReplicateAttrMsg[] =
|
||||
using MetadataMap =
|
||||
llvm::SmallDenseMap<llvm::StringRef, MutableDictionaryAttr, 8>;
|
||||
|
||||
// A set of operations in a cluster.
|
||||
using ClusterOps = llvm::SmallSetVector<Operation*, 8>;
|
||||
|
||||
// Mapping for `_tpu_replicate` attribute to ops of a cluster.
|
||||
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef,
|
||||
llvm::SmallSetVector<Operation*, 8>, 8>;
|
||||
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, ClusterOps, 8>;
|
||||
|
||||
struct TPUClusterFormation
|
||||
: public TF::PerFunctionAggregateAnalysisConsumerPass<
|
||||
@ -91,42 +93,40 @@ struct TPUClusterFormation
|
||||
// attribute to its attributes and removes the ops. If multiple
|
||||
// TPUReplicateMetadata ops have the same `_tpu_replicate` attribute, an error
|
||||
// will be returned.
|
||||
LogicalResult CollectMetadata(Operation* op, MetadataMap* metadata_map) {
|
||||
auto result =
|
||||
op->walk([&](TF::TPUReplicateMetadataOp metadata_op) -> WalkResult {
|
||||
MutableDictionaryAttr attrs = metadata_op.getAttrs();
|
||||
LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) {
|
||||
// Just look at top-level operations in the block (not nested ones)
|
||||
for (Operation& op : llvm::make_early_inc_range(*block)) {
|
||||
auto metadata_op = dyn_cast<TF::TPUReplicateMetadataOp>(op);
|
||||
if (!metadata_op) continue;
|
||||
|
||||
// Missing or bad `_tpu_replicate` attribute.
|
||||
auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
|
||||
if (!tpu_replicate_attr)
|
||||
return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
|
||||
MutableDictionaryAttr attrs = metadata_op.getAttrs();
|
||||
|
||||
auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
|
||||
if (!tpu_replicate_attr_str ||
|
||||
tpu_replicate_attr_str.getValue().empty())
|
||||
return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
|
||||
// Missing or bad `_tpu_replicate` attribute.
|
||||
auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
|
||||
if (!tpu_replicate_attr)
|
||||
return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
|
||||
|
||||
// Remove `name` attribute.
|
||||
attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext()));
|
||||
auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
|
||||
if (!tpu_replicate_attr_str || tpu_replicate_attr_str.getValue().empty())
|
||||
return metadata_op.emitError() << kBadTPUReplicateAttrMsg;
|
||||
|
||||
auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
|
||||
std::move(attrs));
|
||||
// Remove `name` attribute.
|
||||
attrs.remove(Identifier::get(kNameAttr, metadata_op.getContext()));
|
||||
|
||||
// There are multiple TPUReplicateMetadata ops with the same
|
||||
// `_tpu_replicate` attribute.
|
||||
if (!it.second) {
|
||||
return metadata_op.emitError()
|
||||
<< "multiple TPUReplicateMetadata ops with the same '"
|
||||
<< kTPUReplicateAttr << "' attribute '"
|
||||
<< tpu_replicate_attr_str.getValue() << "' found";
|
||||
}
|
||||
auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
|
||||
std::move(attrs));
|
||||
|
||||
metadata_op.erase();
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
// Return failure if the walk was interrupted.
|
||||
return failure(result.wasInterrupted());
|
||||
// There are multiple TPUReplicateMetadata ops with the same
|
||||
// `_tpu_replicate` attribute.
|
||||
if (!it.second) {
|
||||
return metadata_op.emitError()
|
||||
<< "multiple TPUReplicateMetadata ops with the same '"
|
||||
<< kTPUReplicateAttr << "' attribute '"
|
||||
<< tpu_replicate_attr_str.getValue() << "' found";
|
||||
}
|
||||
metadata_op.erase();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Collects and clusters ops with the same `_tpu_replicate` attribute. This will
|
||||
@ -154,12 +154,12 @@ void CollectResourceIdsFromOp(
|
||||
op.walk([&](Operation* inner_op) {
|
||||
for (Value operand : TF::filter_resources(inner_op->getOperands())) {
|
||||
if (resource_alias_analysis.IsUnknownResource(operand)) continue;
|
||||
auto ids = resource_alias_analysis.GetResourceUniqueIds(operand);
|
||||
const auto& ids = resource_alias_analysis.GetResourceUniqueIds(operand);
|
||||
observed_resource_ids.insert(ids.begin(), ids.end());
|
||||
}
|
||||
for (Value result : TF::filter_resources(inner_op->getResults())) {
|
||||
if (resource_alias_analysis.IsUnknownResource(result)) continue;
|
||||
auto ids = resource_alias_analysis.GetResourceUniqueIds(result);
|
||||
const auto& ids = resource_alias_analysis.GetResourceUniqueIds(result);
|
||||
observed_resource_ids.insert(ids.begin(), ids.end());
|
||||
}
|
||||
});
|
||||
@ -168,13 +168,12 @@ void CollectResourceIdsFromOp(
|
||||
// Checks if an op should be moved after a cluster. There may be users of a
|
||||
// cluster interleaved among the cluster ops.
|
||||
bool ShouldMoveOpAfterCluster(
|
||||
Block* block, Operation* op,
|
||||
const llvm::SmallSetVector<Operation*, 8>& cluster_ops,
|
||||
Block* block, Operation* op, const ClusterOps& cluster_ops,
|
||||
const llvm::SmallSetVector<Operation*, 8>& preceding_users,
|
||||
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis,
|
||||
const llvm::SmallDenseSet<int64_t>& observed_resource_ids) {
|
||||
auto result = op->walk([&](Operation* op) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
auto result = op->walk([&](Operation* inner_op) {
|
||||
for (Value operand : inner_op->getOperands()) {
|
||||
Operation* def = operand.getDefiningOp();
|
||||
// Operands may not have a defining op (BlockArgument) or is from a
|
||||
// different block.
|
||||
@ -188,7 +187,7 @@ bool ShouldMoveOpAfterCluster(
|
||||
}
|
||||
|
||||
// Check for uses of any resource in or after cluster.
|
||||
for (Value operand : TF::filter_resources(op->getOperands())) {
|
||||
for (Value operand : TF::filter_resources(inner_op->getOperands())) {
|
||||
if (resource_alias_analysis.IsUnknownResource(operand)) continue;
|
||||
auto ids = resource_alias_analysis.GetResourceUniqueIds(operand);
|
||||
for (const auto& id : ids)
|
||||
@ -208,13 +207,14 @@ bool ShouldMoveOpAfterCluster(
|
||||
// TODO(lyandy): Extend this to handle all side effecting ops while handling
|
||||
// transitive data dependencies.
|
||||
llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
|
||||
Block* block, const llvm::SmallSetVector<Operation*, 8>& cluster_ops,
|
||||
Block* block, const ClusterOps& cluster_ops,
|
||||
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
|
||||
llvm::SmallSetVector<Operation*, 8> preceding_users;
|
||||
llvm::SmallDenseSet<int64_t> observed_resource_ids;
|
||||
|
||||
for (Operation& op : llvm::make_range(Block::iterator(cluster_ops.front()),
|
||||
Block::iterator(cluster_ops.back()))) {
|
||||
auto front = Block::iterator(cluster_ops.front());
|
||||
auto back = Block::iterator(cluster_ops.back());
|
||||
for (Operation& op : llvm::make_range(front, back)) {
|
||||
if (cluster_ops.contains(&op)) {
|
||||
CollectResourceIdsFromOp(op, resource_alias_analysis,
|
||||
observed_resource_ids);
|
||||
@ -236,7 +236,7 @@ llvm::SmallSetVector<Operation*, 8> CollectClusterPrecedingUsers(
|
||||
// outside of the cluster (i.e. results of ops in the cluster are only consumed
|
||||
// by other ops in the cluster) are pruned.
|
||||
llvm::SmallVector<Value, 8> CollectClusterResults(
|
||||
Block* block, const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
|
||||
Block* block, const ClusterOps& cluster_ops) {
|
||||
llvm::SmallVector<Value, 8> results;
|
||||
|
||||
for (Operation* op : cluster_ops) {
|
||||
@ -255,61 +255,52 @@ llvm::SmallVector<Value, 8> CollectClusterResults(
|
||||
}
|
||||
|
||||
// Creates a `tf_device.cluster` to wrap cluster ops.
|
||||
tf_device::ClusterOp CreateOpForCluster(Operation* last_cluster_op,
|
||||
llvm::ArrayRef<Value> results) {
|
||||
tf_device::ClusterOp CreateClusterOp(
|
||||
Block* block, const ClusterOps& cluster_ops, llvm::ArrayRef<Value> results,
|
||||
llvm::ArrayRef<Operation*> preceding_users) {
|
||||
// `tf_device.cluster` will be placed at where the last op of the cluster is.
|
||||
Operation* last_cluster_op = cluster_ops.back();
|
||||
OpBuilder builder(last_cluster_op);
|
||||
|
||||
llvm::SmallVector<Type, 8> result_types;
|
||||
for (Value result : results) result_types.push_back(result.getType());
|
||||
|
||||
auto cluster = builder.create<tf_device::ClusterOp>(last_cluster_op->getLoc(),
|
||||
result_types);
|
||||
|
||||
cluster.body().push_back(new Block);
|
||||
Block* body = new Block;
|
||||
cluster.body().push_back(body);
|
||||
|
||||
// Move cluster ops to the cluster body. Also remove `_tpu_replicate` and
|
||||
// `device` attribute from ops in the cluster as that information will be
|
||||
// present in the `tf_device.cluster`. Do this for all ops including nested
|
||||
// ops.
|
||||
for (Operation* cluster_op : cluster_ops) {
|
||||
cluster_op->moveBefore(body, body->end());
|
||||
cluster_op->walk([&](Operation* inner_op) {
|
||||
inner_op->removeAttr(kTPUReplicateAttr);
|
||||
inner_op->removeAttr(kDeviceAttr);
|
||||
});
|
||||
}
|
||||
|
||||
// Add terminator.
|
||||
builder.setInsertionPointToEnd(&cluster.GetBody());
|
||||
builder.setInsertionPointToEnd(body);
|
||||
builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);
|
||||
|
||||
return cluster;
|
||||
}
|
||||
|
||||
// Moves cluster ops to associated `tf_device.cluster` body.
|
||||
void MoveClusterOpsToCluster(
|
||||
tf_device::ClusterOp cluster,
|
||||
const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
|
||||
MLIRContext* context = cluster.getContext();
|
||||
Operation* terminator = cluster.GetBody().getTerminator();
|
||||
|
||||
for (Operation* cluster_op : cluster_ops) {
|
||||
// Remove `_tpu_replicate` and `device` attribute from ops in the cluster
|
||||
// as that information will be present in the `tf_device.cluster`.
|
||||
cluster_op->removeAttr(Identifier::get(kTPUReplicateAttr, context));
|
||||
cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
|
||||
cluster_op->moveBefore(terminator);
|
||||
}
|
||||
}
|
||||
|
||||
// Replaces uses of cluster ops results outside of cluster with the associated
|
||||
// `tf_device.cluster` results.
|
||||
void UpdateClusterResultExternalUses(tf_device::ClusterOp cluster,
|
||||
llvm::ArrayRef<Value> results) {
|
||||
Block& cluster_block = cluster.GetBody();
|
||||
// Replaces uses of cluster ops results outside of cluster with the associated
|
||||
// `tf_device.cluster` results.
|
||||
for (auto ret_vals : llvm::zip(results, cluster.getResults())) {
|
||||
Value old_ret = std::get<0>(ret_vals);
|
||||
Value new_ret = std::get<1>(ret_vals);
|
||||
for (auto& use : llvm::make_early_inc_range(old_ret.getUses()))
|
||||
if (!cluster_block.findAncestorOpInBlock(*use.getOwner()))
|
||||
use.set(new_ret);
|
||||
for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) {
|
||||
Operation* user = use.getOwner();
|
||||
if (!body->findAncestorOpInBlock(*user)) use.set(new_ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Moves users of cluster that are before the cluster to after the cluster.
|
||||
void MovePrecedingClusterUsers(tf_device::ClusterOp cluster,
|
||||
llvm::ArrayRef<Operation*> preceding_users) {
|
||||
// Move users of cluster that are before the cluster to after the cluster.
|
||||
Operation* op_after_cluster = cluster.getOperation()->getNextNode();
|
||||
for (Operation* user : preceding_users) user->moveBefore(op_after_cluster);
|
||||
return cluster;
|
||||
}
|
||||
|
||||
// Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
|
||||
@ -490,10 +481,29 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
|
||||
// attribute `num_replicas` is greater than 1.
|
||||
// 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
|
||||
LogicalResult FormClustersInBlock(
|
||||
Block* block, const MetadataMap& metadata_map,
|
||||
Block* block,
|
||||
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
|
||||
MetadataMap metadata_map;
|
||||
LogicalResult result = CollectMetadata(block, &metadata_map);
|
||||
if (failed(result)) return result;
|
||||
|
||||
// If there is no TPUReplicateMetadata op in this block, process blocks in
|
||||
// regions attached to the ops in the block.
|
||||
if (metadata_map.empty()) {
|
||||
for (Operation& op : *block) {
|
||||
for (Region& region : op.getRegions()) {
|
||||
if (!llvm::hasSingleElement(region))
|
||||
return op.emitOpError("Expected single block region");
|
||||
if (failed(
|
||||
FormClustersInBlock(&region.front(), resource_alias_analysis)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
ClusterMap clusters;
|
||||
LogicalResult result = CollectAndGroupClusterOps(block, &clusters);
|
||||
result = CollectAndGroupClusterOps(block, &clusters);
|
||||
if (failed(result)) return result;
|
||||
|
||||
for (const auto& cluster_metadata_and_ops : clusters) {
|
||||
@ -518,14 +528,8 @@ LogicalResult FormClustersInBlock(
|
||||
llvm::SmallVector<Value, 8> results =
|
||||
CollectClusterResults(block, cluster_ops);
|
||||
|
||||
tf_device::ClusterOp cluster =
|
||||
CreateOpForCluster(cluster_ops.back(), results);
|
||||
|
||||
MoveClusterOpsToCluster(cluster, cluster_ops);
|
||||
|
||||
UpdateClusterResultExternalUses(cluster, results);
|
||||
|
||||
MovePrecedingClusterUsers(cluster, preceding_users.getArrayRef());
|
||||
tf_device::ClusterOp cluster = CreateClusterOp(
|
||||
block, cluster_ops, results, preceding_users.getArrayRef());
|
||||
|
||||
auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr);
|
||||
if (!num_replicas || !num_replicas.isa<mlir::IntegerAttr>())
|
||||
@ -548,13 +552,13 @@ LogicalResult FormClustersInBlock(
|
||||
void TPUClusterFormation::runOnFunction(
|
||||
FuncOp func,
|
||||
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
|
||||
MetadataMap metadata_map;
|
||||
if (failed(CollectMetadata(func, &metadata_map))) return signalPassFailure();
|
||||
if (!llvm::hasSingleElement(func)) {
|
||||
func.emitOpError("Expecting a single block function");
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
for (Block& block : func)
|
||||
if (failed(
|
||||
FormClustersInBlock(&block, metadata_map, resource_alias_analysis)))
|
||||
return signalPassFailure();
|
||||
if (failed(FormClustersInBlock(&func.front(), resource_alias_analysis)))
|
||||
return signalPassFailure();
|
||||
|
||||
// Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
|
||||
auto remove_result = func.walk([&](Operation* op) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user