From 5ff1abdd7ad67e6164104018321421028e1f4caa Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Tue, 15 Dec 2020 14:58:52 -0800 Subject: [PATCH] NFC: Minor modifications to the tpu-outside-compilation-cluster pass Specifically, - Filter out marked ops for outside compilation and run cluster assignment only on them. - Use llvm::concat to combine data and control dependency - Exit early if the op is not safe to add This will help in the follow-up changes to the pass. PiperOrigin-RevId: 347699261 Change-Id: I8c948ea8f7d19086b95564efabcb66fc06a0153a --- .../tpu_outside_compilation_cluster.cc | 64 ++++++++----------- 1 file changed, 26 insertions(+), 38 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc index 7f0acbd9aea..e8d61b468d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc @@ -127,46 +127,33 @@ class OutsideCompiledCluster { // Checks if it is safe for `op` to be merged into this cluster. bool IsSafeToAdd(Operation* op, const TF::SideEffectAnalysis::Info& side_effect_analysis) { - // If the op is not marked for outside compilation it doesn't belong in a - // cluster. - if (!op->getAttrOfType(kXlaOutsideCompilationAttr)) - return false; - if (host_cluster_ops_.empty()) return true; // If there is an intermediate data or side effect dependency between the op // and ops in the cluster, it's not safe to add. llvm::SmallSetVector op_stack; - for (auto* user : op->getUsers()) { - if (!host_cluster_ops_.contains(user)) op_stack.insert(user); - } - for (auto* successor : side_effect_analysis.DirectControlSuccessors(op)) { - if (!host_cluster_ops_.contains(successor)) op_stack.insert(successor); - } - bool safe_to_add = true; - while (!op_stack.empty()) { - auto* next_op = op_stack.pop_back_val(); - for (auto* user : next_op->getUsers()) { - if (host_cluster_ops_.contains(user)) { - safe_to_add = false; - break; - } else { - op_stack.insert(user); - } - } - for (auto* successor : - side_effect_analysis.DirectControlSuccessors(next_op)) { - if (host_cluster_ops_.contains(successor)) { - safe_to_add = false; - break; - } else { - op_stack.insert(successor); - } - } - if (!safe_to_add) break; + + // Materialize data dependencies as the llvm::concat doesn't support + // non-materialized iteration. + auto data_deps = llvm::to_vector<4>(op->getUsers()); + llvm::SmallVector control_deps = + side_effect_analysis.DirectControlSuccessors(op); + for (auto* dep : llvm::concat(data_deps, control_deps)) { + if (!host_cluster_ops_.contains(dep)) op_stack.insert(dep); } - return safe_to_add; + while (!op_stack.empty()) { + auto* next_op = op_stack.pop_back_val(); + auto data_deps = llvm::to_vector<4>(next_op->getUsers()); + llvm::SmallVector control_deps = + side_effect_analysis.DirectControlSuccessors(next_op); + for (auto* dep : llvm::concat(data_deps, control_deps)) { + if (host_cluster_ops_.contains(dep)) return false; + op_stack.insert(dep); + } + } + + return true; } // `host_cluster_op_` stores a set of ops that will be grouped and computed @@ -183,14 +170,15 @@ void TPUOutsideCompilationCluster::runOnFunction( int cluster_counter = 0; func.walk([&](tf_device::ClusterOp tpu_cluster) { - llvm::SmallVector tpu_cluster_ops; - tpu_cluster_ops.reserve(tpu_cluster.getBody()->getOperations().size()); - - tpu_cluster.walk([&](Operation* op) { tpu_cluster_ops.emplace_back(op); }); + llvm::SmallVector outside_ops; + tpu_cluster.walk([&](Operation* op) { + if (op->getAttrOfType(kXlaOutsideCompilationAttr)) + outside_ops.emplace_back(op); + }); // In order to cluster ops feeding results to the same operation, traverse // the ops in reverse order. - for (Operation* op : llvm::reverse(tpu_cluster_ops)) { + for (Operation* op : llvm::reverse(outside_ops)) { // Try to add the op to existing clusters. bool added = false; for (auto& cluster : clusters)