diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index f841e7cf7a1..43974e02bba 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -103,12 +104,26 @@ using llvm::cl::opt; // Commandline flag to enable the control of flatbuffer import. bool use_external_constant; +// Commandline flag to enable graph pruning. +bool experimental_prune_unreachable_nodes_unconditionally; + // NOLINTNEXTLINE static opt use_external_constant_flag( "use-external-constant", llvm::cl::desc("Use external constant during flatbuffer import"), llvm::cl::location(use_external_constant), llvm::cl::init(false)); +// TODO(b/147111261): After the importer supports generic custom ops, we should +// change the flag to a more lightweight flag, e.g. +// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE prune +// the operations. +// NOLINTNEXTLINE +static opt experimental_prune_unreachable_nodes_unconditionally_flg( + "experimental-prune-unreachable-nodes-unconditionally", + llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."), + llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally), + llvm::cl::init(false)); + namespace { bool IsScalar(const TensorT& tensor) { // TODO(b/138222071) We can't distinguish scalars and unranked tensors @@ -268,8 +283,10 @@ StatusOr OpNameForOpCode(const tflite::OperatorCodeT opcode) { if (custom_name == "MaxUnpooling2D") { return std::string("tfl.max_unpooling_2d"); } - return errors::Unimplemented("unsupported custom operation: ", - opcode.custom_code); + // Use an unsupported op name instead of throwing an error here in case the + // op is pruned during the import. 
+ return std::string( + llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str()); } if (opcode.builtin_code == tflite::BuiltinOperator_IF) { return std::string("tf.If"); @@ -645,6 +662,49 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute( name, builder->getStringAttr(llvm::join(tensor_names, ","))); } +// Given a list of output indices, traverses the subgraph and returns the set of +// ops that are ancestors of the output tensors. +StatusOr> PruneSubgraph( + const tflite::SubGraphT& subgraph, ArrayRef output_indices) { + // Create a map from tensor index to defining op. + absl::flat_hash_map defining_op; + for (const auto& op : subgraph.operators) { + for (int32_t output : op->outputs) { + defining_op[output] = op.get(); + } + } + + std::vector queue; + for (int32_t output : output_indices) { + if (auto& op = defining_op[output]) { + queue.push_back(op); + } else { + return errors::InvalidArgument("Output tensor doesn't have defining op"); + } + } + + // Traverse the graph towards inputs. + absl::flat_hash_set visited; + while (!queue.empty()) { + const tflite::OperatorT* op = queue.back(); + queue.pop_back(); + if (!visited.insert(op).second) { + // The node has already been visited. + continue; + } + + for (int32_t input : op->inputs) { + // Input tensor may not have a defining op in case it is a subgraph input + // or a constant tensor. + if (auto& op = defining_op[input]) { + queue.push_back(op); + } + } + } + + return visited; +} + // Build a FuncOp from a tflite SubGraph // The op_names are a mapping from indexes into the TFLite operators array to // the operator name MLIR expects (tfl.foo_op). 
The buffers are directly taken @@ -661,7 +721,8 @@ StatusOr ConvertSubgraph( const std::vector>& buffers, Location base_loc, Builder builder, const std::vector& ordered_output_arrays, bool is_entry_point, - bool use_external_constant) { + bool use_external_constant, + bool experimental_prune_unreachable_nodes_unconditionally) { llvm::SmallVector ret_types; llvm::SmallVector input_types; @@ -757,8 +818,19 @@ StatusOr ConvertSubgraph( func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes)); } + absl::flat_hash_set pruned_subgraph_ops; + if (experimental_prune_unreachable_nodes_unconditionally) { + TF_ASSIGN_OR_RETURN(pruned_subgraph_ops, + PruneSubgraph(subgraph, func_outputs)); + } + // Construct MLIR operators from TFLite operators for (auto& op : subgraph.operators) { + if (experimental_prune_unreachable_nodes_unconditionally && + !pruned_subgraph_ops.contains(op)) { + continue; + } + for (auto input_num : op->inputs) { // The operators in a graph are topologically sorted // and so if no previous operation has produced a tensor @@ -863,7 +935,8 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) { OwningModuleRef tflite::FlatBufferToMlir( absl::string_view buffer, MLIRContext* context, Location base_loc, const std::vector& ordered_output_arrays, - bool use_external_constant) { + bool use_external_constant, + bool experimental_prune_unreachable_nodes_unconditionally) { auto model_ptr = FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length()); if (nullptr == model_ptr) { @@ -918,7 +991,8 @@ OwningModuleRef tflite::FlatBufferToMlir( // TODO(b/131175224,b/132239787) Support multiple entry points builder, ordered_output_arrays, /*is_entry_point=*/e.index() == 0, - /*use_external_constant=*/use_external_constant); + /*use_external_constant=*/use_external_constant, + experimental_prune_unreachable_nodes_unconditionally); if (!func_or_error.ok()) { return emitError(base_loc, "could not translate function ") << 
subgraph->name, @@ -931,9 +1005,10 @@ OwningModuleRef tflite::FlatBufferToMlir( return OwningModuleRef(module); } -static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr, - MLIRContext* context, - bool use_external_constant) { +static OwningModuleRef FlatBufferFileToMlirTrans( + llvm::SourceMgr* source_mgr, MLIRContext* context, + bool use_external_constant, + bool experimental_prune_unreachable_nodes_unconditionally) { const llvm::MemoryBuffer* input = source_mgr->getMemoryBuffer(source_mgr->getMainFileID()); std::string error; @@ -950,12 +1025,14 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr, return tflite::FlatBufferToMlir( absl::string_view(input->getBufferStart(), input->getBufferSize()), - context, loc, outputs, use_external_constant); + context, loc, outputs, use_external_constant, + experimental_prune_unreachable_nodes_unconditionally); } static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( "tflite-flatbuffer-to-mlir", [](llvm::SourceMgr& source_mgr, MLIRContext* context) { - return FlatBufferFileToMlirTrans(&source_mgr, context, - use_external_constant); + return FlatBufferFileToMlirTrans( + &source_mgr, context, use_external_constant, + experimental_prune_unreachable_nodes_unconditionally); }); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.h b/tensorflow/compiler/mlir/lite/flatbuffer_import.h index 92a4a10adbb..e3210c6d03f 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.h @@ -31,11 +31,14 @@ namespace tflite { // on failure, and more specific errors will be emitted via the context. // If `use_external_constant` is true, it will create `tfl.external_const` // instead of `tfl.const`. +// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes that +// are not ancestors of the output nodes will be pruned. 
mlir::OwningModuleRef FlatBufferToMlir( absl::string_view buffer, mlir::MLIRContext* context, mlir::Location base_loc, const std::vector& ordered_output_arrays, - bool use_external_constant = false); + bool use_external_constant = false, + bool experimental_prune_unreachable_nodes_unconditionally = false); } // namespace tflite #endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir index d228cc06a88..20df2f75732 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir @@ -11,6 +11,8 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { %3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div") // CHECK: %[[EXP:.*]] = "tfl.exp" %4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp") + // tfl.neg should not be pruned + // CHECK: %[[NEG:.*]] = "tfl.neg" %5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg") // CHECK: return %[[MUL]], %[[EXP]], %[[DIV]] return %5 : tensor<4xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir new file mode 100644 index 00000000000..0d7f911f282 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir @@ -0,0 +1,19 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -output-arrays=mul,exp,div --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s +// Confirm graph pruning. 
+ +func @main(tensor<4xf32>) -> tensor<4xf32> { +^bb0(%arg0: tensor<4xf32>): + %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference") + // CHECK: %[[MUL:.*]] = tfl.mul + %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul") + // CHECK: %[[DIV:.*]] = tfl.div + %3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div") + // CHECK: %[[EXP:.*]] = "tfl.exp" + %4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp") + // tfl.neg should be pruned + // CHECK-NOT: "tfl.neg" + %5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg") + // CHECK: return %[[MUL]], %[[EXP]], %[[DIV]] + return %5 : tensor<4xf32> +}