From 572354c8fda7e06b02d0faf2c810084546e3933b Mon Sep 17 00:00:00 2001
From: Abdurrahman Akkas <akkas@google.com>
Date: Fri, 3 Jan 2020 14:07:25 -0800
Subject: [PATCH] Add graph pruning to flatbuffer importer.

When a user specifies custom outputs, unsupported ops in the unused parts of the graph cause the import to fail. With this change, the importer traverses the subgraph and prunes ops that don't affect the subgraph outputs.

PiperOrigin-RevId: 288046610
Change-Id: I849fb7d0d04ca8f2812bbd1f4037d46a2f66bedf
---
 .../compiler/mlir/lite/flatbuffer_import.cc   | 99 ++++++++++++++++---
 .../compiler/mlir/lite/flatbuffer_import.h    |  5 +-
 .../tests/flatbuffer2mlir/output_arrays.mlir  |  2 +
 .../lite/tests/flatbuffer2mlir/pruning.mlir   | 19 ++++
 4 files changed, 113 insertions(+), 12 deletions(-)
 create mode 100644 tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir

diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index f841e7cf7a1..43974e02bba 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <algorithm>
 #include <cctype>
+#include <cstdint>
 #include <iostream>
 #include <sstream>
 #include <string>
@@ -103,12 +104,26 @@ using llvm::cl::opt;
 // Commandline flag to enable the control of flatbuffer import.
 bool use_external_constant;
 
+// Commandline flag to enable graph pruning.
+bool experimental_prune_unreachable_nodes_unconditionally;
+
 // NOLINTNEXTLINE
 static opt<bool, true> use_external_constant_flag(
     "use-external-constant",
     llvm::cl::desc("Use external constant during flatbuffer import"),
     llvm::cl::location(use_external_constant), llvm::cl::init(false));
 
+// TODO(b/147111261): After the importer supports generic custom ops, we should
+// change the flag to a more lightweight flag, e.g.
+// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE prune
+// the operations.
+// NOLINTNEXTLINE
+static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
+    "experimental-prune-unreachable-nodes-unconditionally",
+    llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
+    llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
+    llvm::cl::init(false));
+
 namespace {
 bool IsScalar(const TensorT& tensor) {
   // TODO(b/138222071) We can't distinguish scalars and unranked tensors
@@ -268,8 +283,10 @@ StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
     if (custom_name == "MaxUnpooling2D") {
       return std::string("tfl.max_unpooling_2d");
     }
-    return errors::Unimplemented("unsupported custom operation: ",
-                                 opcode.custom_code);
+    // Use an unsupported op name instead of throwing an error here in case the
+    // op is pruned during the import.
+    return std::string(
+        llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str());
   }
   if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
     return std::string("tf.If");
@@ -645,6 +662,49 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
       name, builder->getStringAttr(llvm::join(tensor_names, ",")));
 }
 
+// Given a list of output tensor indices, traverses the subgraph towards its
+// inputs and returns the set of ops that are ancestors of those tensors.
+// Returns InvalidArgument if an output tensor is not produced by any op.
+StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
+    const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
+  // Create a map from tensor index to defining op.
+  absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
+  for (const auto& op : subgraph.operators) {
+    for (int32_t output : op->outputs) {
+      defining_op[output] = op.get();
+    }
+  }
+
+  // Use find() for lookups so that probing a tensor without a defining op
+  // does not insert a null entry into the map.
+  std::vector<const tflite::OperatorT*> queue;
+  for (int32_t output : output_indices) {
+    auto it = defining_op.find(output);
+    if (it == defining_op.end()) {
+      return errors::InvalidArgument("Output tensor doesn't have defining op");
+    }
+    queue.push_back(it->second);
+  }
+
+  // Traverse the graph towards inputs.
+  absl::flat_hash_set<const tflite::OperatorT*> visited;
+  while (!queue.empty()) {
+    const tflite::OperatorT* op = queue.back();
+    queue.pop_back();
+    // Skip nodes that have already been visited.
+    if (!visited.insert(op).second) continue;
+
+    for (int32_t input : op->inputs) {
+      // Input tensor may not have a defining op in case it is a subgraph input
+      // or a constant tensor.
+      auto it = defining_op.find(input);
+      if (it != defining_op.end()) queue.push_back(it->second);
+    }
+  }
+
+  return visited;
+}
+
 // Build a FuncOp from a tflite SubGraph
 // The op_names are a mapping from indexes into the TFLite operators array to
 // the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
@@ -661,7 +721,8 @@ StatusOr<FuncOp> ConvertSubgraph(
     const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
     Location base_loc, Builder builder,
     const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
-    bool use_external_constant) {
+    bool use_external_constant,
+    bool experimental_prune_unreachable_nodes_unconditionally) {
   llvm::SmallVector<mlir::Type, 2> ret_types;
   llvm::SmallVector<mlir::Type, 4> input_types;
 
@@ -757,8 +818,19 @@ StatusOr<FuncOp> ConvertSubgraph(
     func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
   }
 
+  absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
+  if (experimental_prune_unreachable_nodes_unconditionally) {
+    TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
+                        PruneSubgraph(subgraph, func_outputs));
+  }
+
   // Construct MLIR operators from TFLite operators
   for (auto& op : subgraph.operators) {
+    if (experimental_prune_unreachable_nodes_unconditionally &&
+        !pruned_subgraph_ops.contains(op)) {
+      continue;
+    }
+
     for (auto input_num : op->inputs) {
       // The operators in a graph are topologically sorted
       // and so if no previous operation has produced a tensor
@@ -863,7 +935,8 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
 OwningModuleRef tflite::FlatBufferToMlir(
     absl::string_view buffer, MLIRContext* context, Location base_loc,
     const std::vector<std::string>& ordered_output_arrays,
-    bool use_external_constant) {
+    bool use_external_constant,
+    bool experimental_prune_unreachable_nodes_unconditionally) {
   auto model_ptr =
       FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
   if (nullptr == model_ptr) {
@@ -918,7 +991,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
         // TODO(b/131175224,b/132239787) Support multiple entry points
         builder, ordered_output_arrays,
         /*is_entry_point=*/e.index() == 0,
-        /*use_external_constant=*/use_external_constant);
+        /*use_external_constant=*/use_external_constant,
+        experimental_prune_unreachable_nodes_unconditionally);
     if (!func_or_error.ok()) {
       return emitError(base_loc, "could not translate function ")
                  << subgraph->name,
@@ -931,9 +1005,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
   return OwningModuleRef(module);
 }
 
-static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
-                                                 MLIRContext* context,
-                                                 bool use_external_constant) {
+static OwningModuleRef FlatBufferFileToMlirTrans(
+    llvm::SourceMgr* source_mgr, MLIRContext* context,
+    bool use_external_constant,
+    bool experimental_prune_unreachable_nodes_unconditionally) {
   const llvm::MemoryBuffer* input =
       source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
   std::string error;
@@ -950,12 +1025,14 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
 
   return tflite::FlatBufferToMlir(
       absl::string_view(input->getBufferStart(), input->getBufferSize()),
-      context, loc, outputs, use_external_constant);
+      context, loc, outputs, use_external_constant,
+      experimental_prune_unreachable_nodes_unconditionally);
 }
 
 static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
     "tflite-flatbuffer-to-mlir",
     [](llvm::SourceMgr& source_mgr, MLIRContext* context) {
-      return FlatBufferFileToMlirTrans(&source_mgr, context,
-                                       use_external_constant);
+      return FlatBufferFileToMlirTrans(
+          &source_mgr, context, use_external_constant,
+          experimental_prune_unreachable_nodes_unconditionally);
     });
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.h b/tensorflow/compiler/mlir/lite/flatbuffer_import.h
index 92a4a10adbb..e3210c6d03f 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.h
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.h
@@ -31,11 +31,14 @@ namespace tflite {
 // on failure, and more specific errors will be emitted via the context.
 // If `use_external_constant` is true, it will create `tfl.external_const`
 // instead of `tfl.const`.
+// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes that
+// are not ancestors of the output nodes will be pruned.
 mlir::OwningModuleRef FlatBufferToMlir(
     absl::string_view buffer, mlir::MLIRContext* context,
     mlir::Location base_loc,
     const std::vector<std::string>& ordered_output_arrays,
-    bool use_external_constant = false);
+    bool use_external_constant = false,
+    bool experimental_prune_unreachable_nodes_unconditionally = false);
 }  // namespace tflite
 
 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir
index d228cc06a88..20df2f75732 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir
@@ -11,6 +11,8 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
   %3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
   // CHECK: %[[EXP:.*]] = "tfl.exp"
   %4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
+  // tfl.neg should not be pruned
+  // CHECK: %[[NEG:.*]] = "tfl.neg"
   %5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
   // CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
   return %5 : tensor<4xf32>
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir
new file mode 100644
index 00000000000..0d7f911f282
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir
@@ -0,0 +1,19 @@
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -output-arrays=mul,exp,div --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// Confirm graph pruning.
+
+func @main(tensor<4xf32>) -> tensor<4xf32> {
+^bb0(%arg0: tensor<4xf32>):
+  %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
+  %1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
+  // CHECK: %[[MUL:.*]] = tfl.mul
+  %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
+  // CHECK: %[[DIV:.*]] = tfl.div
+  %3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
+  // CHECK: %[[EXP:.*]] = "tfl.exp"
+  %4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
+  // tfl.neg should be pruned
+  // CHECK-NOT: "tfl.neg"
+  %5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
+  // CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
+  return %5 : tensor<4xf32>
+}