NFC: Minor modifications to the tpu-outside-compilation-cluster pass
Specifically, - Filter out marked ops for outside compilation and run cluster assignment only on them. - Use llvm::concat to combine data and control dependency - Exit early if the op is not safe to add This will help in the follow-up changes to the pass. PiperOrigin-RevId: 347699261 Change-Id: I8c948ea8f7d19086b95564efabcb66fc06a0153a
This commit is contained in:
parent
5a63699c77
commit
5ff1abdd7a
@ -127,46 +127,33 @@ class OutsideCompiledCluster {
|
|||||||
// Checks if it is safe for `op` to be merged into this cluster.
|
// Checks if it is safe for `op` to be merged into this cluster.
|
||||||
bool IsSafeToAdd(Operation* op,
|
bool IsSafeToAdd(Operation* op,
|
||||||
const TF::SideEffectAnalysis::Info& side_effect_analysis) {
|
const TF::SideEffectAnalysis::Info& side_effect_analysis) {
|
||||||
// If the op is not marked for outside compilation it doesn't belong in a
|
|
||||||
// cluster.
|
|
||||||
if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
if (host_cluster_ops_.empty()) return true;
|
if (host_cluster_ops_.empty()) return true;
|
||||||
|
|
||||||
// If there is an intermediate data or side effect dependency between the op
|
// If there is an intermediate data or side effect dependency between the op
|
||||||
// and ops in the cluster, it's not safe to add.
|
// and ops in the cluster, it's not safe to add.
|
||||||
llvm::SmallSetVector<Operation*, 4> op_stack;
|
llvm::SmallSetVector<Operation*, 4> op_stack;
|
||||||
for (auto* user : op->getUsers()) {
|
|
||||||
if (!host_cluster_ops_.contains(user)) op_stack.insert(user);
|
// Materialize data dependencies as the llvm::concat doesn't support
|
||||||
}
|
// non-materialized iteration.
|
||||||
for (auto* successor : side_effect_analysis.DirectControlSuccessors(op)) {
|
auto data_deps = llvm::to_vector<4>(op->getUsers());
|
||||||
if (!host_cluster_ops_.contains(successor)) op_stack.insert(successor);
|
llvm::SmallVector<Operation*, 4> control_deps =
|
||||||
}
|
side_effect_analysis.DirectControlSuccessors(op);
|
||||||
bool safe_to_add = true;
|
for (auto* dep : llvm::concat<Operation*>(data_deps, control_deps)) {
|
||||||
while (!op_stack.empty()) {
|
if (!host_cluster_ops_.contains(dep)) op_stack.insert(dep);
|
||||||
auto* next_op = op_stack.pop_back_val();
|
|
||||||
for (auto* user : next_op->getUsers()) {
|
|
||||||
if (host_cluster_ops_.contains(user)) {
|
|
||||||
safe_to_add = false;
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
op_stack.insert(user);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (auto* successor :
|
|
||||||
side_effect_analysis.DirectControlSuccessors(next_op)) {
|
|
||||||
if (host_cluster_ops_.contains(successor)) {
|
|
||||||
safe_to_add = false;
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
op_stack.insert(successor);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!safe_to_add) break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return safe_to_add;
|
while (!op_stack.empty()) {
|
||||||
|
auto* next_op = op_stack.pop_back_val();
|
||||||
|
auto data_deps = llvm::to_vector<4>(next_op->getUsers());
|
||||||
|
llvm::SmallVector<Operation*, 4> control_deps =
|
||||||
|
side_effect_analysis.DirectControlSuccessors(next_op);
|
||||||
|
for (auto* dep : llvm::concat<Operation*>(data_deps, control_deps)) {
|
||||||
|
if (host_cluster_ops_.contains(dep)) return false;
|
||||||
|
op_stack.insert(dep);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// `host_cluster_op_` stores a set of ops that will be grouped and computed
|
// `host_cluster_op_` stores a set of ops that will be grouped and computed
|
||||||
@ -183,14 +170,15 @@ void TPUOutsideCompilationCluster::runOnFunction(
|
|||||||
int cluster_counter = 0;
|
int cluster_counter = 0;
|
||||||
|
|
||||||
func.walk([&](tf_device::ClusterOp tpu_cluster) {
|
func.walk([&](tf_device::ClusterOp tpu_cluster) {
|
||||||
llvm::SmallVector<Operation*, 4> tpu_cluster_ops;
|
llvm::SmallVector<Operation*, 4> outside_ops;
|
||||||
tpu_cluster_ops.reserve(tpu_cluster.getBody()->getOperations().size());
|
tpu_cluster.walk([&](Operation* op) {
|
||||||
|
if (op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
|
||||||
tpu_cluster.walk([&](Operation* op) { tpu_cluster_ops.emplace_back(op); });
|
outside_ops.emplace_back(op);
|
||||||
|
});
|
||||||
|
|
||||||
// In order to cluster ops feeding results to the same operation, traverse
|
// In order to cluster ops feeding results to the same operation, traverse
|
||||||
// the ops in reverse order.
|
// the ops in reverse order.
|
||||||
for (Operation* op : llvm::reverse(tpu_cluster_ops)) {
|
for (Operation* op : llvm::reverse(outside_ops)) {
|
||||||
// Try to add the op to existing clusters.
|
// Try to add the op to existing clusters.
|
||||||
bool added = false;
|
bool added = false;
|
||||||
for (auto& cluster : clusters)
|
for (auto& cluster : clusters)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user