[Cleanup] Some minor cleanups in flatbuffer_import.cc
PiperOrigin-RevId: 326477418 Change-Id: I8cb0c4f682d1fa194a413bb0b682a87fe24e8406
This commit is contained in:
parent
466275b90e
commit
f03dc8cff2
@ -254,20 +254,35 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
|
||||
layer_stats, axis_stats, axis);
|
||||
}
|
||||
|
||||
StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
|
||||
if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
|
||||
// Returns true if this is a basic LSTM op.
|
||||
bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
|
||||
if (const auto* op = op_union.AsLSTMOptions()) {
|
||||
return op->kernel_type == tflite::LSTMKernelType_BASIC;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Gets the MLIR op name with the dialect name for the flatbuffer operator.
|
||||
StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
|
||||
const tflite::OperatorCodeT& op_code) {
|
||||
if (IsBasicLSTMOp(op.builtin_options)) {
|
||||
return std::string("tfl.basic_lstm");
|
||||
}
|
||||
|
||||
if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) {
|
||||
return std::string("tfl.custom");
|
||||
}
|
||||
if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
|
||||
if (op_code.builtin_code == tflite::BuiltinOperator_IF) {
|
||||
return std::string("tf.If");
|
||||
}
|
||||
if (opcode.builtin_code == tflite::BuiltinOperator_WHILE) {
|
||||
if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) {
|
||||
return std::string("tf.While");
|
||||
}
|
||||
|
||||
const char* op_name = tflite::EnumNameBuiltinOperator(opcode.builtin_code);
|
||||
std::string lowered_name = llvm::StringRef(op_name).lower();
|
||||
return llvm::Twine("tfl.", lowered_name).str();
|
||||
llvm::StringRef op_name(
|
||||
tflite::EnumNameBuiltinOperator(op_code.builtin_code));
|
||||
return llvm::Twine("tfl.", op_name.lower()).str();
|
||||
}
|
||||
|
||||
// The buffers in TFLite flatbuffers have their contents stored as a vector of
|
||||
@ -510,14 +525,6 @@ llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
|
||||
return {};
|
||||
}
|
||||
|
||||
// Returns true if this is a basic LSTM op.
|
||||
bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
|
||||
if (const auto* op = op_union.AsLSTMOptions()) {
|
||||
return op->kernel_type == tflite::LSTMKernelType_BASIC;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(krzysd) Handle function calls
|
||||
StatusOr<Operation*> ConvertOp(
|
||||
@ -525,7 +532,6 @@ StatusOr<Operation*> ConvertOp(
|
||||
const std::vector<mlir::TensorType>& intermediate_types,
|
||||
Value optional_arg_marker,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
|
||||
const std::vector<std::string>& op_names,
|
||||
const std::vector<std::string>& func_names,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
|
||||
OpBuilder builder) {
|
||||
@ -537,10 +543,10 @@ StatusOr<Operation*> ConvertOp(
|
||||
return emitError(loc, err.ToString()), err;
|
||||
}
|
||||
|
||||
const bool is_basic_lstm = IsBasicLSTMOp(op.builtin_options);
|
||||
const tflite::OperatorCodeT op_code = *op_codes.at(op.opcode_index);
|
||||
const std::string& op_name =
|
||||
is_basic_lstm ? "tfl.basic_lstm" : op_names.at(op.opcode_index);
|
||||
const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code));
|
||||
|
||||
OperationState op_state(loc, op_name);
|
||||
|
||||
for (auto input_num : op.inputs) {
|
||||
@ -791,8 +797,7 @@ static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
|
||||
}
|
||||
|
||||
// Build a FuncOp from a tflite SubGraph
|
||||
// The op_names are a mapping from indexes into the TFLite operators array to
|
||||
// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
|
||||
// The buffers are directly taken
|
||||
// from the deserialized flatbuffer as we do not have the type information to
|
||||
// interpret them until this point. The base_loc parameter is the location of
|
||||
// the flatbuffer as a whole (usually a file). The is_entry_point flag
|
||||
@ -802,7 +807,6 @@ static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
|
||||
StatusOr<FuncOp> ConvertSubgraph(
|
||||
const tflite::SubGraphT& subgraph, llvm::StringRef name,
|
||||
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
|
||||
const std::vector<std::string>& op_names,
|
||||
const std::vector<std::string>& func_names,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
||||
Location base_loc, Builder builder, bool is_entry_point,
|
||||
@ -1002,8 +1006,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto* mlir_op,
|
||||
ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
|
||||
op_codes, op_names, func_names, subgraph.tensors, op_loc,
|
||||
op_builder));
|
||||
op_codes, func_names, subgraph.tensors, op_loc, op_builder));
|
||||
|
||||
// Add the results to the value maps. There are two cases: 1. the result
|
||||
// tensor does not have min/max values, the original op result is used
|
||||
@ -1079,17 +1082,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
||||
|
||||
auto builder = Builder(context);
|
||||
|
||||
std::vector<std::string> operator_names;
|
||||
operator_names.reserve(model->operator_codes.size());
|
||||
|
||||
for (auto& opcode : model->operator_codes) {
|
||||
auto operator_name_or_error = OpNameForOpCode(*opcode);
|
||||
if (!operator_name_or_error.ok()) {
|
||||
return emitError(base_loc, operator_name_or_error.status().ToString()),
|
||||
nullptr;
|
||||
}
|
||||
operator_names.push_back(operator_name_or_error.ConsumeValueOrDie());
|
||||
}
|
||||
|
||||
std::vector<std::string> func_names;
|
||||
for (auto& subgraph : model->subgraphs) {
|
||||
@ -1110,8 +1102,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
||||
auto& subgraph = e.value();
|
||||
std::string name = SubgraphName(e.index(), *subgraph);
|
||||
auto func_or_error = ConvertSubgraph(
|
||||
*subgraph, name, model->operator_codes, operator_names, func_names,
|
||||
model->buffers, base_loc, builder,
|
||||
*subgraph, name, model->operator_codes, func_names, model->buffers,
|
||||
base_loc, builder,
|
||||
// TODO(b/131175224,b/132239787) Support multiple entry points
|
||||
/*is_entry_point=*/e.index() == 0,
|
||||
/*use_external_constant=*/use_external_constant, ordered_input_arrays,
|
||||
|
Loading…
Reference in New Issue
Block a user