[Cleanup] Some minor cleanups in flatbuffer_import.cc

PiperOrigin-RevId: 326477418
Change-Id: I8cb0c4f682d1fa194a413bb0b682a87fe24e8406
Jing Pu 2020-08-13 10:41:25 -07:00 committed by TensorFlower Gardener
parent 466275b90e
commit f03dc8cff2


@@ -254,20 +254,35 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
                                              layer_stats, axis_stats, axis);
 }
-StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
-  if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
+// Returns true if this is a basic LSTM op.
+bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
+  if (const auto* op = op_union.AsLSTMOptions()) {
+    return op->kernel_type == tflite::LSTMKernelType_BASIC;
+  } else {
+    return false;
+  }
+}
+// Gets the MLIR op name with the dialect name for the flatbuffer operator.
+StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
+                                    const tflite::OperatorCodeT& op_code) {
+  if (IsBasicLSTMOp(op.builtin_options)) {
+    return std::string("tfl.basic_lstm");
+  }
+  if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) {
     return std::string("tfl.custom");
   }
-  if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
+  if (op_code.builtin_code == tflite::BuiltinOperator_IF) {
     return std::string("tf.If");
   }
-  if (opcode.builtin_code == tflite::BuiltinOperator_WHILE) {
+  if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) {
     return std::string("tf.While");
   }
-  const char* op_name = tflite::EnumNameBuiltinOperator(opcode.builtin_code);
-  std::string lowered_name = llvm::StringRef(op_name).lower();
-  return llvm::Twine("tfl.", lowered_name).str();
+  llvm::StringRef op_name(
+      tflite::EnumNameBuiltinOperator(op_code.builtin_code));
+  return llvm::Twine("tfl.", op_name.lower()).str();
 }
 // The buffers in TFLite flatbuffers have their contents stored as a vector of
@@ -510,14 +525,6 @@ llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
   return {};
 }
-// Returns true if this is a basic LSTM op.
-bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
-  if (const auto* op = op_union.AsLSTMOptions()) {
-    return op->kernel_type == tflite::LSTMKernelType_BASIC;
-  } else {
-    return false;
-  }
-}
 // TODO(krzysd) Handle function calls
 StatusOr<Operation*> ConvertOp(
@@ -525,7 +532,6 @@ StatusOr<Operation*> ConvertOp(
     const std::vector<mlir::TensorType>& intermediate_types,
     Value optional_arg_marker,
     const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
-    const std::vector<std::string>& op_names,
     const std::vector<std::string>& func_names,
     const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
     OpBuilder builder) {
@@ -537,10 +543,10 @@ StatusOr<Operation*> ConvertOp(
     return emitError(loc, err.ToString()), err;
   }
-  const bool is_basic_lstm = IsBasicLSTMOp(op.builtin_options);
-  const tflite::OperatorCodeT op_code = *op_codes.at(op.opcode_index);
-  const std::string& op_name =
-      is_basic_lstm ? "tfl.basic_lstm" : op_names.at(op.opcode_index);
+  const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index);
+  TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code));
   OperationState op_state(loc, op_name);
   for (auto input_num : op.inputs) {
@@ -791,8 +797,7 @@ static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
 }
 // Build a FuncOp from a tflite SubGraph
-// The op_names are a mapping from indexes into the TFLite operators array to
-// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken
+// The buffers are directly taken
 // from the deserialized flatbuffer as we do not have the type information to
 // interpret them until this point. The base_loc parameter is the location of
 // the flatbuffer as a whole (usually a file). The is_entry_point flag
@@ -802,7 +807,6 @@ static StatusOr<FuncOp> PostProcessFuncOp(FuncOp func) {
 StatusOr<FuncOp> ConvertSubgraph(
     const tflite::SubGraphT& subgraph, llvm::StringRef name,
     const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
-    const std::vector<std::string>& op_names,
     const std::vector<std::string>& func_names,
     const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
     Location base_loc, Builder builder, bool is_entry_point,
@@ -1002,8 +1006,7 @@ StatusOr<FuncOp> ConvertSubgraph(
     TF_ASSIGN_OR_RETURN(
         auto* mlir_op,
         ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
-                  op_codes, op_names, func_names, subgraph.tensors, op_loc,
-                  op_builder));
+                  op_codes, func_names, subgraph.tensors, op_loc, op_builder));
     // Add the results to the value maps. There are two cases: 1. the result
     // tensor does not have min/max values, the original op result is used
@@ -1079,17 +1082,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
   auto builder = Builder(context);
-  std::vector<std::string> operator_names;
-  operator_names.reserve(model->operator_codes.size());
-  for (auto& opcode : model->operator_codes) {
-    auto operator_name_or_error = OpNameForOpCode(*opcode);
-    if (!operator_name_or_error.ok()) {
-      return emitError(base_loc, operator_name_or_error.status().ToString()),
-             nullptr;
-    }
-    operator_names.push_back(operator_name_or_error.ConsumeValueOrDie());
-  }
   std::vector<std::string> func_names;
   for (auto& subgraph : model->subgraphs) {
@@ -1110,8 +1102,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
     auto& subgraph = e.value();
     std::string name = SubgraphName(e.index(), *subgraph);
     auto func_or_error = ConvertSubgraph(
-        *subgraph, name, model->operator_codes, operator_names, func_names,
-        model->buffers, base_loc, builder,
+        *subgraph, name, model->operator_codes, func_names, model->buffers,
+        base_loc, builder,
         // TODO(b/131175224,b/132239787) Support multiple entry points
         /*is_entry_point=*/e.index() == 0,
         /*use_external_constant=*/use_external_constant, ordered_input_arrays,
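
Illustration (not part of the diff): a minimal, self-contained sketch of the name-lowering that the new GetMlirOpName helper applies to ordinary builtin operators, shown in isolation. It assumes LLVM's StringRef and Twine headers are available; the function TflOpNameForBuiltin and the hard-coded "CONV_2D" input are hypothetical stand-ins for the value returned by tflite::EnumNameBuiltinOperator.

#include <iostream>
#include <string>

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"

// Hypothetical helper: prefix a builtin operator name (e.g. "CONV_2D") with
// the TFLite dialect name and lower-case it, mirroring the tail of
// GetMlirOpName above.
std::string TflOpNameForBuiltin(const char* builtin_name) {
  llvm::StringRef op_name(builtin_name);
  // Twine concatenates lazily and references temporaries, so it must be
  // materialized with .str() within the same statement.
  return llvm::Twine("tfl.", op_name.lower()).str();
}

int main() {
  std::cout << TflOpNameForBuiltin("CONV_2D") << "\n";  // prints "tfl.conv_2d"
  return 0;
}

This is the same construction that maps, for example, BuiltinOperator_CONV_2D to the MLIR op name "tfl.conv_2d"; the CUSTOM, IF, WHILE, and basic-LSTM cases are handled by the explicit branches in GetMlirOpName before this fallback runs.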