[Cleanup] Some minor cleanups in flatbuffer_import.cc
PiperOrigin-RevId: 326477418 Change-Id: I8cb0c4f682d1fa194a413bb0b682a87fe24e8406
This commit is contained in:
parent
466275b90e
commit
f03dc8cff2
@ -254,20 +254,35 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
|
|||||||
layer_stats, axis_stats, axis);
|
layer_stats, axis_stats, axis);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
|
// Returns true if this is a basic LSTM op.
|
||||||
if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
|
bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
|
||||||
|
if (const auto* op = op_union.AsLSTMOptions()) {
|
||||||
|
return op->kernel_type == tflite::LSTMKernelType_BASIC;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets the MLIR op name with the dialect name for the flatbuffer operator.
|
||||||
|
StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
|
||||||
|
const tflite::OperatorCodeT& op_code) {
|
||||||
|
if (IsBasicLSTMOp(op.builtin_options)) {
|
||||||
|
return std::string("tfl.basic_lstm");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) {
|
||||||
return std::string("tfl.custom");
|
return std::string("tfl.custom");
|
||||||
}
|
}
|
||||||
if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
|
if (op_code.builtin_code == tflite::BuiltinOperator_IF) {
|
||||||
return std::string("tf.If");
|
return std::string("tf.If");
|
||||||
}
|
}
|
||||||
if (opcode.builtin_code == tflite::BuiltinOperator_WHILE) {
|
if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) {
|
||||||
return std::string("tf.While");
|
return std::string("tf.While");
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* op_name = tflite::EnumNameBuiltinOperator(opcode.builtin_code);
|
llvm::StringRef op_name(
|
||||||
std::string lowered_name = llvm::StringRef(op_name).lower();
|
tflite::EnumNameBuiltinOperator(op_code.builtin_code));
|
||||||
return llvm::Twine("tfl.", lowered_name).str();
|
return llvm::Twine("tfl.", op_name.lower()).str();
|
||||||
}
|
}
|
||||||
|
|
||||||
// The buffers in TFLite flatbuffers have their contents stored as a vector of
|
// The buffers in TFLite flatbuffers have their contents stored as a vector of
|
||||||
@ -510,14 +525,6 @@ llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if this is a basic LSTM op.
|
|
||||||
bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
|
|
||||||
if (const auto* op = op_union.AsLSTMOptions()) {
|
|
||||||
return op->kernel_type == tflite::LSTMKernelType_BASIC;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(krzysd) Handle function calls
|
// TODO(krzysd) Handle function calls
|
||||||
StatusOr<Operation*> ConvertOp(
|
StatusOr<Operation*> ConvertOp(
|
||||||
@ -525,7 +532,6 @@ StatusOr<Operation*> ConvertOp(
|
|||||||
const std::vector<mlir::TensorType>& intermediate_types,
|
const std::vector<mlir::TensorType>& intermediate_types,
|
||||||
Value optional_arg_marker,
|
Value optional_arg_marker,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
|
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
|
||||||
const std::vector<std::string>& op_names,
|
|
||||||
const std::vector<std::string>& func_names,
|
const std::vector<std::string>& func_names,
|
||||||
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
|
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
|
||||||
OpBuilder builder) {
|
OpBuilder builder) {
|
||||||
@ -537,10 +543,10 @@ StatusOr<Operation*> ConvertOp(
|
|||||||
return emitError(loc, err.ToString()), err;
|
return emitError(loc, err.ToString()), err;
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool is_basic_lstm = IsBasicLSTMOp(op.builtin_options);
|
const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index);
|
||||||
const tflite::OperatorCodeT op_code = *op_codes.at(op.opcode_index);
|
|
||||||
const std::string& op_name =
|
TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code));
|
||||||
is_basic_lstm ? "tfl.basic_lstm" : op_names.at(op.opcode_index);
|
|
||||||
OperationState op_state(loc, op_name);
|
OperationState op_state(loc, op_name);
|
||||||
|
|
||||||
for (auto input_num : op.inputs) {
|
for (auto input_num : op.inputs) {
|
||||||
@ -791,8 +797,7 @@ static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build a FuncOp from a tflite SubGraph
|
// Build a FuncOp from a tflite SubGraph
|
||||||
// The op_names are a mapping from indexes into the TFLite operators array to
|
// The buffers are directly taken
|
||||||
// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
|
|
||||||
// from the deserialized flatbuffer as we do not have the type information to
|
// from the deserialized flatbuffer as we do not have the type information to
|
||||||
// interpret them until this point. The base_loc parameter is the location of
|
// interpret them until this point. The base_loc parameter is the location of
|
||||||
// the flatbuffer as a whole (usually a file). The is_entry_point flag
|
// the flatbuffer as a whole (usually a file). The is_entry_point flag
|
||||||
@ -802,7 +807,6 @@ static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
|
|||||||
StatusOr<FuncOp> ConvertSubgraph(
|
StatusOr<FuncOp> ConvertSubgraph(
|
||||||
const tflite::SubGraphT& subgraph, llvm::StringRef name,
|
const tflite::SubGraphT& subgraph, llvm::StringRef name,
|
||||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
|
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
|
||||||
const std::vector<std::string>& op_names,
|
|
||||||
const std::vector<std::string>& func_names,
|
const std::vector<std::string>& func_names,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
||||||
Location base_loc, Builder builder, bool is_entry_point,
|
Location base_loc, Builder builder, bool is_entry_point,
|
||||||
@ -1002,8 +1006,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
|||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
auto* mlir_op,
|
auto* mlir_op,
|
||||||
ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
|
ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
|
||||||
op_codes, op_names, func_names, subgraph.tensors, op_loc,
|
op_codes, func_names, subgraph.tensors, op_loc, op_builder));
|
||||||
op_builder));
|
|
||||||
|
|
||||||
// Add the results to the value maps. There are two cases: 1. the result
|
// Add the results to the value maps. There are two cases: 1. the result
|
||||||
// tensor does not have min/max values, the original op result is used
|
// tensor does not have min/max values, the original op result is used
|
||||||
@ -1079,17 +1082,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
|||||||
|
|
||||||
auto builder = Builder(context);
|
auto builder = Builder(context);
|
||||||
|
|
||||||
std::vector<std::string> operator_names;
|
|
||||||
operator_names.reserve(model->operator_codes.size());
|
|
||||||
|
|
||||||
for (auto& opcode : model->operator_codes) {
|
|
||||||
auto operator_name_or_error = OpNameForOpCode(*opcode);
|
|
||||||
if (!operator_name_or_error.ok()) {
|
|
||||||
return emitError(base_loc, operator_name_or_error.status().ToString()),
|
|
||||||
nullptr;
|
|
||||||
}
|
|
||||||
operator_names.push_back(operator_name_or_error.ConsumeValueOrDie());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> func_names;
|
std::vector<std::string> func_names;
|
||||||
for (auto& subgraph : model->subgraphs) {
|
for (auto& subgraph : model->subgraphs) {
|
||||||
@ -1110,8 +1102,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
|||||||
auto& subgraph = e.value();
|
auto& subgraph = e.value();
|
||||||
std::string name = SubgraphName(e.index(), *subgraph);
|
std::string name = SubgraphName(e.index(), *subgraph);
|
||||||
auto func_or_error = ConvertSubgraph(
|
auto func_or_error = ConvertSubgraph(
|
||||||
*subgraph, name, model->operator_codes, operator_names, func_names,
|
*subgraph, name, model->operator_codes, func_names, model->buffers,
|
||||||
model->buffers, base_loc, builder,
|
base_loc, builder,
|
||||||
// TODO(b/131175224,b/132239787) Support multiple entry points
|
// TODO(b/131175224,b/132239787) Support multiple entry points
|
||||||
/*is_entry_point=*/e.index() == 0,
|
/*is_entry_point=*/e.index() == 0,
|
||||||
/*use_external_constant=*/use_external_constant, ordered_input_arrays,
|
/*use_external_constant=*/use_external_constant, ordered_input_arrays,
|
||||||
|
Loading…
Reference in New Issue
Block a user