Add graph pruning to flatbuffer importer.
When user specifies custom outputs, unsupported ops in the unused parts of the graph cause import to fail. With this change importer traverses subgraph and prunes ops that don't affect subgraph outputs. PiperOrigin-RevId: 288046610 Change-Id: I849fb7d0d04ca8f2812bbd1f4037d46a2f66bedf
This commit is contained in:
parent
728967326e
commit
572354c8fd
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cctype>
|
#include <cctype>
|
||||||
|
#include <cstdint>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -103,12 +104,26 @@ using llvm::cl::opt;
|
|||||||
// Commandline flag to enable the control of flatbuffer import.
|
// Commandline flag to enable the control of flatbuffer import.
|
||||||
bool use_external_constant;
|
bool use_external_constant;
|
||||||
|
|
||||||
|
// Commandline flag to enable graph pruning.
|
||||||
|
bool experimental_prune_unreachable_nodes_unconditionally;
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
static opt<bool, true> use_external_constant_flag(
|
static opt<bool, true> use_external_constant_flag(
|
||||||
"use-external-constant",
|
"use-external-constant",
|
||||||
llvm::cl::desc("Use external constant during flatbuffer import"),
|
llvm::cl::desc("Use external constant during flatbuffer import"),
|
||||||
llvm::cl::location(use_external_constant), llvm::cl::init(false));
|
llvm::cl::location(use_external_constant), llvm::cl::init(false));
|
||||||
|
|
||||||
|
// TODO(b/147111261): After the importer supports generic custom ops, we should
|
||||||
|
// change the flag to a more lightwise flag, e.g.
|
||||||
|
// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune
|
||||||
|
// the operations.
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
|
||||||
|
"experimental-prune-unreachable-nodes-unconditionally",
|
||||||
|
llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
|
||||||
|
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
|
||||||
|
llvm::cl::init(false));
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
bool IsScalar(const TensorT& tensor) {
|
bool IsScalar(const TensorT& tensor) {
|
||||||
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
|
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
|
||||||
@ -268,8 +283,10 @@ StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
|
|||||||
if (custom_name == "MaxUnpooling2D") {
|
if (custom_name == "MaxUnpooling2D") {
|
||||||
return std::string("tfl.max_unpooling_2d");
|
return std::string("tfl.max_unpooling_2d");
|
||||||
}
|
}
|
||||||
return errors::Unimplemented("unsupported custom operation: ",
|
// Use an unsupported op name instead of throwing an error here in case the
|
||||||
opcode.custom_code);
|
// op is pruned during the import.
|
||||||
|
return std::string(
|
||||||
|
llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str());
|
||||||
}
|
}
|
||||||
if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
|
if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
|
||||||
return std::string("tf.If");
|
return std::string("tf.If");
|
||||||
@ -645,6 +662,49 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
|
|||||||
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
|
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Given a list of output indices, traverses the subgraph and returns the set of
|
||||||
|
// ops that are ancestors of the output tensors.
|
||||||
|
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
|
||||||
|
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
|
||||||
|
// Create a map from tensor index to defining op.
|
||||||
|
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
|
||||||
|
for (const auto& op : subgraph.operators) {
|
||||||
|
for (int32_t output : op->outputs) {
|
||||||
|
defining_op[output] = op.get();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<const tflite::OperatorT*> queue;
|
||||||
|
for (int32_t output : output_indices) {
|
||||||
|
if (auto& op = defining_op[output]) {
|
||||||
|
queue.push_back(op);
|
||||||
|
} else {
|
||||||
|
return errors::InvalidArgument("Output tensor doesn't have defining op");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse the graph towards inputs.
|
||||||
|
absl::flat_hash_set<const tflite::OperatorT*> visited;
|
||||||
|
while (!queue.empty()) {
|
||||||
|
const tflite::OperatorT* op = queue.back();
|
||||||
|
queue.pop_back();
|
||||||
|
if (!visited.insert(op).second) {
|
||||||
|
// The node has already been visited.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int32_t input : op->inputs) {
|
||||||
|
// Input tensor may not have a defining op in case it is a subgraph input
|
||||||
|
// or a constant tensor.
|
||||||
|
if (auto& op = defining_op[input]) {
|
||||||
|
queue.push_back(op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return visited;
|
||||||
|
}
|
||||||
|
|
||||||
// Build a FuncOp from a tflite SubGraph
|
// Build a FuncOp from a tflite SubGraph
|
||||||
// The op_names are a mapping from indexes into the TFLite operators array to
|
// The op_names are a mapping from indexes into the TFLite operators array to
|
||||||
// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
|
// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
|
||||||
@ -661,7 +721,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
|||||||
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
||||||
Location base_loc, Builder builder,
|
Location base_loc, Builder builder,
|
||||||
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
|
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
|
||||||
bool use_external_constant) {
|
bool use_external_constant,
|
||||||
|
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||||
llvm::SmallVector<mlir::Type, 2> ret_types;
|
llvm::SmallVector<mlir::Type, 2> ret_types;
|
||||||
llvm::SmallVector<mlir::Type, 4> input_types;
|
llvm::SmallVector<mlir::Type, 4> input_types;
|
||||||
|
|
||||||
@ -757,8 +818,19 @@ StatusOr<FuncOp> ConvertSubgraph(
|
|||||||
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
|
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
|
||||||
|
if (experimental_prune_unreachable_nodes_unconditionally) {
|
||||||
|
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
|
||||||
|
PruneSubgraph(subgraph, func_outputs));
|
||||||
|
}
|
||||||
|
|
||||||
// Construct MLIR operators from TFLite operators
|
// Construct MLIR operators from TFLite operators
|
||||||
for (auto& op : subgraph.operators) {
|
for (auto& op : subgraph.operators) {
|
||||||
|
if (experimental_prune_unreachable_nodes_unconditionally &&
|
||||||
|
!pruned_subgraph_ops.contains(op)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
for (auto input_num : op->inputs) {
|
for (auto input_num : op->inputs) {
|
||||||
// The operators in a graph are topologically sorted
|
// The operators in a graph are topologically sorted
|
||||||
// and so if no previous operation has produced a tensor
|
// and so if no previous operation has produced a tensor
|
||||||
@ -863,7 +935,8 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
|
|||||||
OwningModuleRef tflite::FlatBufferToMlir(
|
OwningModuleRef tflite::FlatBufferToMlir(
|
||||||
absl::string_view buffer, MLIRContext* context, Location base_loc,
|
absl::string_view buffer, MLIRContext* context, Location base_loc,
|
||||||
const std::vector<std::string>& ordered_output_arrays,
|
const std::vector<std::string>& ordered_output_arrays,
|
||||||
bool use_external_constant) {
|
bool use_external_constant,
|
||||||
|
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||||
auto model_ptr =
|
auto model_ptr =
|
||||||
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
|
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
|
||||||
if (nullptr == model_ptr) {
|
if (nullptr == model_ptr) {
|
||||||
@ -918,7 +991,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
|||||||
// TODO(b/131175224,b/132239787) Support multiple entry points
|
// TODO(b/131175224,b/132239787) Support multiple entry points
|
||||||
builder, ordered_output_arrays,
|
builder, ordered_output_arrays,
|
||||||
/*is_entry_point=*/e.index() == 0,
|
/*is_entry_point=*/e.index() == 0,
|
||||||
/*use_external_constant=*/use_external_constant);
|
/*use_external_constant=*/use_external_constant,
|
||||||
|
experimental_prune_unreachable_nodes_unconditionally);
|
||||||
if (!func_or_error.ok()) {
|
if (!func_or_error.ok()) {
|
||||||
return emitError(base_loc, "could not translate function ")
|
return emitError(base_loc, "could not translate function ")
|
||||||
<< subgraph->name,
|
<< subgraph->name,
|
||||||
@ -931,9 +1005,10 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
|||||||
return OwningModuleRef(module);
|
return OwningModuleRef(module);
|
||||||
}
|
}
|
||||||
|
|
||||||
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
static OwningModuleRef FlatBufferFileToMlirTrans(
|
||||||
MLIRContext* context,
|
llvm::SourceMgr* source_mgr, MLIRContext* context,
|
||||||
bool use_external_constant) {
|
bool use_external_constant,
|
||||||
|
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||||
const llvm::MemoryBuffer* input =
|
const llvm::MemoryBuffer* input =
|
||||||
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
|
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
|
||||||
std::string error;
|
std::string error;
|
||||||
@ -950,12 +1025,14 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
|||||||
|
|
||||||
return tflite::FlatBufferToMlir(
|
return tflite::FlatBufferToMlir(
|
||||||
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
||||||
context, loc, outputs, use_external_constant);
|
context, loc, outputs, use_external_constant,
|
||||||
|
experimental_prune_unreachable_nodes_unconditionally);
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
|
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
|
||||||
"tflite-flatbuffer-to-mlir",
|
"tflite-flatbuffer-to-mlir",
|
||||||
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
|
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
|
||||||
return FlatBufferFileToMlirTrans(&source_mgr, context,
|
return FlatBufferFileToMlirTrans(
|
||||||
use_external_constant);
|
&source_mgr, context, use_external_constant,
|
||||||
|
experimental_prune_unreachable_nodes_unconditionally);
|
||||||
});
|
});
|
||||||
|
@ -31,11 +31,14 @@ namespace tflite {
|
|||||||
// on failure, and more specific errors will be emitted via the context.
|
// on failure, and more specific errors will be emitted via the context.
|
||||||
// If `use_external_constant` is true, it will create `tfl.external_const`
|
// If `use_external_constant` is true, it will create `tfl.external_const`
|
||||||
// instead of `tfl.const`.
|
// instead of `tfl.const`.
|
||||||
|
// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes that
|
||||||
|
// are not ancestors of the output nodes will be pruned.
|
||||||
mlir::OwningModuleRef FlatBufferToMlir(
|
mlir::OwningModuleRef FlatBufferToMlir(
|
||||||
absl::string_view buffer, mlir::MLIRContext* context,
|
absl::string_view buffer, mlir::MLIRContext* context,
|
||||||
mlir::Location base_loc,
|
mlir::Location base_loc,
|
||||||
const std::vector<std::string>& ordered_output_arrays,
|
const std::vector<std::string>& ordered_output_arrays,
|
||||||
bool use_external_constant = false);
|
bool use_external_constant = false,
|
||||||
|
bool experimental_prune_unreachable_nodes_unconditionally = false);
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
|
||||||
|
@ -11,6 +11,8 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
|||||||
%3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
|
%3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
|
||||||
// CHECK: %[[EXP:.*]] = "tfl.exp"
|
// CHECK: %[[EXP:.*]] = "tfl.exp"
|
||||||
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
|
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
|
||||||
|
// tfl.neg should not be pruned
|
||||||
|
// CHECK: %[[NEG:.*]] = "tfl.neg"
|
||||||
%5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
|
%5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
|
||||||
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
|
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
|
||||||
return %5 : tensor<4xf32>
|
return %5 : tensor<4xf32>
|
||||||
|
@ -0,0 +1,19 @@
|
|||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -output-arrays=mul,exp,div --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
|
||||||
|
// Confirm graph pruning.
|
||||||
|
|
||||||
|
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||||
|
^bb0(%arg0: tensor<4xf32>):
|
||||||
|
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||||
|
%1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
|
||||||
|
// CHECK: %[[MUL:.*]] = tfl.mul
|
||||||
|
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
|
||||||
|
// CHECK: %[[DIV:.*]] = tfl.div
|
||||||
|
%3 = "tfl.div"(%2, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
|
||||||
|
// CHECK: %[[EXP:.*]] = "tfl.exp"
|
||||||
|
%4 = "tfl.exp"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
|
||||||
|
// tfl.neg should be pruned
|
||||||
|
// CHECK-NOT: "tfl.neg"
|
||||||
|
%5 = "tfl.neg"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
|
||||||
|
// CHECK: return %[[MUL]], %[[EXP]], %[[DIV]]
|
||||||
|
return %5 : tensor<4xf32>
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user