diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 3c8bf26aa14..c46c4a7bfc2 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -254,20 +254,35 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b, layer_stats, axis_stats, axis); } -StatusOr OpNameForOpCode(const tflite::OperatorCodeT opcode) { - if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) { +// Returns true if this is a basic LSTM op. +bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) { + if (const auto* op = op_union.AsLSTMOptions()) { + return op->kernel_type == tflite::LSTMKernelType_BASIC; + } else { + return false; + } +} + +// Gets the MLIR op name with the dialect name for the flatbuffer operator. +StatusOr GetMlirOpName(const tflite::OperatorT& op, + const tflite::OperatorCodeT& op_code) { + if (IsBasicLSTMOp(op.builtin_options)) { + return std::string("tfl.basic_lstm"); + } + + if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) { return std::string("tfl.custom"); } - if (opcode.builtin_code == tflite::BuiltinOperator_IF) { + if (op_code.builtin_code == tflite::BuiltinOperator_IF) { return std::string("tf.If"); } - if (opcode.builtin_code == tflite::BuiltinOperator_WHILE) { + if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) { return std::string("tf.While"); } - const char* op_name = tflite::EnumNameBuiltinOperator(opcode.builtin_code); - std::string lowered_name = llvm::StringRef(op_name).lower(); - return llvm::Twine("tfl.", lowered_name).str(); + llvm::StringRef op_name( + tflite::EnumNameBuiltinOperator(op_code.builtin_code)); + return llvm::Twine("tfl.", op_name.lower()).str(); } // The buffers in TFLite flatbuffers have their contents stored as a vector of @@ -510,14 +525,6 @@ llvm::SmallVector ConvertSubgraphIdxsToFunctionAttrs( return {}; } -// Returns true if this is a basic LSTM op. -bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) { - if (const auto* op = op_union.AsLSTMOptions()) { - return op->kernel_type == tflite::LSTMKernelType_BASIC; - } else { - return false; - } -} // TODO(krzysd) Handle function calls StatusOr ConvertOp( @@ -525,7 +532,6 @@ StatusOr ConvertOp( const std::vector& intermediate_types, Value optional_arg_marker, const std::vector>& op_codes, - const std::vector& op_names, const std::vector& func_names, const std::vector>& tensors, Location loc, OpBuilder builder) { @@ -537,10 +543,10 @@ StatusOr ConvertOp( return emitError(loc, err.ToString()), err; } - const bool is_basic_lstm = IsBasicLSTMOp(op.builtin_options); - const tflite::OperatorCodeT op_code = *op_codes.at(op.opcode_index); - const std::string& op_name = - is_basic_lstm ? "tfl.basic_lstm" : op_names.at(op.opcode_index); + const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index); + + TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code)); + OperationState op_state(loc, op_name); for (auto input_num : op.inputs) { @@ -791,8 +797,7 @@ static StatusOr PostProcessFuncOp(FuncOp func) { } // Build a FuncOp from a tflite SubGraph -// The op_names are a mapping from indexes into the TFLite operators array to -// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken +// The buffers are directly taken // from the deserialized flatbuffer as we do not have the type information to // interpret them until this point. The base_loc parameter is the location of // the flatbuffer as a whole (usually a file). The is_entry_point flag @@ -802,7 +807,6 @@ static StatusOr PostProcessFuncOp(FuncOp func) { StatusOr ConvertSubgraph( const tflite::SubGraphT& subgraph, llvm::StringRef name, const std::vector>& op_codes, - const std::vector& op_names, const std::vector& func_names, const std::vector>& buffers, Location base_loc, Builder builder, bool is_entry_point, @@ -1002,8 +1006,7 @@ StatusOr ConvertSubgraph( TF_ASSIGN_OR_RETURN( auto* mlir_op, ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker, - op_codes, op_names, func_names, subgraph.tensors, op_loc, - op_builder)); + op_codes, func_names, subgraph.tensors, op_loc, op_builder)); // Add the results to the value maps. There are two cases: 1. the result // tensor does not have min/max values, the original op result is used @@ -1079,17 +1082,6 @@ OwningModuleRef tflite::FlatBufferToMlir( auto builder = Builder(context); - std::vector operator_names; - operator_names.reserve(model->operator_codes.size()); - - for (auto& opcode : model->operator_codes) { - auto operator_name_or_error = OpNameForOpCode(*opcode); - if (!operator_name_or_error.ok()) { - return emitError(base_loc, operator_name_or_error.status().ToString()), - nullptr; - } - operator_names.push_back(operator_name_or_error.ConsumeValueOrDie()); - } std::vector func_names; for (auto& subgraph : model->subgraphs) { @@ -1110,8 +1102,8 @@ OwningModuleRef tflite::FlatBufferToMlir( auto& subgraph = e.value(); std::string name = SubgraphName(e.index(), *subgraph); auto func_or_error = ConvertSubgraph( - *subgraph, name, model->operator_codes, operator_names, func_names, - model->buffers, base_loc, builder, + *subgraph, name, model->operator_codes, func_names, model->buffers, + base_loc, builder, // TODO(b/131175224,b/132239787) Support multiple entry points /*is_entry_point=*/e.index() == 0, /*use_external_constant=*/use_external_constant, ordered_input_arrays,