NFC: Minor modifications to the tpu-outside-compilation-cluster pass

Specifically,
- Filter out marked ops for outside compilation and run cluster assignment only on them.
- Use llvm::concat to combine data and control dependency
- Exit early if the op is not safe to add

This will help in the follow-up changes to the pass.

PiperOrigin-RevId: 347699261
Change-Id: I8c948ea8f7d19086b95564efabcb66fc06a0153a
This commit is contained in:
Smit Hinsu 2020-12-15 14:58:52 -08:00 committed by TensorFlower Gardener
parent 5a63699c77
commit 5ff1abdd7a

View File

@ -127,46 +127,33 @@ class OutsideCompiledCluster {
// Checks if it is safe for `op` to be merged into this cluster.
bool IsSafeToAdd(Operation* op,
const TF::SideEffectAnalysis::Info& side_effect_analysis) {
// If the op is not marked for outside compilation it doesn't belong in a
// cluster.
if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
return false;
if (host_cluster_ops_.empty()) return true;
// If there is an intermediate data or side effect dependency between the op
// and ops in the cluster, it's not safe to add.
llvm::SmallSetVector<Operation*, 4> op_stack;
for (auto* user : op->getUsers()) {
if (!host_cluster_ops_.contains(user)) op_stack.insert(user);
}
for (auto* successor : side_effect_analysis.DirectControlSuccessors(op)) {
if (!host_cluster_ops_.contains(successor)) op_stack.insert(successor);
}
bool safe_to_add = true;
while (!op_stack.empty()) {
auto* next_op = op_stack.pop_back_val();
for (auto* user : next_op->getUsers()) {
if (host_cluster_ops_.contains(user)) {
safe_to_add = false;
break;
} else {
op_stack.insert(user);
}
}
for (auto* successor :
side_effect_analysis.DirectControlSuccessors(next_op)) {
if (host_cluster_ops_.contains(successor)) {
safe_to_add = false;
break;
} else {
op_stack.insert(successor);
}
}
if (!safe_to_add) break;
// Materialize data dependencies as the llvm::concat doesn't support
// non-materialized iteration.
auto data_deps = llvm::to_vector<4>(op->getUsers());
llvm::SmallVector<Operation*, 4> control_deps =
side_effect_analysis.DirectControlSuccessors(op);
for (auto* dep : llvm::concat<Operation*>(data_deps, control_deps)) {
if (!host_cluster_ops_.contains(dep)) op_stack.insert(dep);
}
return safe_to_add;
while (!op_stack.empty()) {
auto* next_op = op_stack.pop_back_val();
auto data_deps = llvm::to_vector<4>(next_op->getUsers());
llvm::SmallVector<Operation*, 4> control_deps =
side_effect_analysis.DirectControlSuccessors(next_op);
for (auto* dep : llvm::concat<Operation*>(data_deps, control_deps)) {
if (host_cluster_ops_.contains(dep)) return false;
op_stack.insert(dep);
}
}
return true;
}
// `host_cluster_op_` stores a set of ops that will be grouped and computed
@ -183,14 +170,15 @@ void TPUOutsideCompilationCluster::runOnFunction(
int cluster_counter = 0;
func.walk([&](tf_device::ClusterOp tpu_cluster) {
llvm::SmallVector<Operation*, 4> tpu_cluster_ops;
tpu_cluster_ops.reserve(tpu_cluster.getBody()->getOperations().size());
tpu_cluster.walk([&](Operation* op) { tpu_cluster_ops.emplace_back(op); });
llvm::SmallVector<Operation*, 4> outside_ops;
tpu_cluster.walk([&](Operation* op) {
if (op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
outside_ops.emplace_back(op);
});
// In order to cluster ops feeding results to the same operation, traverse
// the ops in reverse order.
for (Operation* op : llvm::reverse(tpu_cluster_ops)) {
for (Operation* op : llvm::reverse(outside_ops)) {
// Try to add the op to existing clusters.
bool added = false;
for (auto& cluster : clusters)