diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 446ba89a3f1..03cf9265f3b 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -224,6 +224,7 @@ cc_library( deps = [ ":tensorflow_lite_ops_inc_gen", ":validators", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/lite/schema:schema_fbs", "@llvm-project//llvm:support", @@ -553,14 +554,14 @@ cc_library( cc_library( name = "flatbuffer_translate_lib", srcs = [ - "flatbuffer_export.cc", "flatbuffer_import.cc", + "flatbuffer_translate.cc", "utils/convert_type.cc", ], hdrs = [ - "flatbuffer_export.h", - "flatbuffer_export_flags.h", "flatbuffer_import.h", + "flatbuffer_translate.h", + "flatbuffer_translate_flags.h", "utils/convert_type.h", ], deps = [ @@ -578,10 +579,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:logging", - "//tensorflow/core/platform:status", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:string_util", @@ -602,32 +601,15 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", ], -) - -cc_library( - name = "flatbuffer_translate_registeration", - srcs = [ - "flatbuffer_translate.cc", - ], - deps = [ - ":flatbuffer_translate_lib", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LoopOpsTransforms", - "@llvm-project//mlir:MlirTranslateMain", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Translation", - ], alwayslink = 1, ) tf_cc_binary( name = "flatbuffer_translate", deps = [ - ":flatbuffer_translate_registeration", + ":flatbuffer_translate_lib", + "@llvm-project//mlir:LoopOpsTransforms", + "@llvm-project//mlir:MlirTranslateMain", ], ) @@ -665,13 +647,10 @@ filegroup( tf_cc_binary( name = "tf_tfl_translate", - srcs = [ - ":tf_tfl_translate_main", - ], + srcs = [":tf_tfl_translate_main"], deps = [ ":common", ":flatbuffer_translate_lib", - ":flatbuffer_translate_registeration", ":tensorflow_lite", ":tf_tfl_passes", ":tf_tfl_translate_cl_options", @@ -693,18 +672,15 @@ tf_cc_binary( tf_cc_binary( name = "mlir-tflite-runner", - srcs = [ - "mlir_tflite_runner.cc", - ], + srcs = ["mlir_tflite_runner.cc"], deps = [ ":flatbuffer_translate_lib", - ":flatbuffer_translate_registeration", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc deleted file mode 100644 index 72e9b8c742a..00000000000 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ /dev/null @@ -1,1455 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "flatbuffers/flatbuffers.h" // TF:flatbuffers -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/ToolOutputFile.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // TF:llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Attributes.h" // TF:llvm-project -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Location.h" // TF:llvm-project -#include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "mlir/IR/Module.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "mlir/IR/Types.h" // TF:llvm-project -#include "mlir/IR/Value.h" // TF:llvm-project -#include "mlir/Support/LogicalResult.h" // TF:llvm-project -#include "mlir/Translation.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" -#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" -#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/string_util.h" -#include "tensorflow/lite/tools/versioning/op_version.h" -#include "tensorflow/lite/tools/versioning/runtime_version.h" -#include "tensorflow/lite/version.h" - -using llvm::dyn_cast; -using llvm::formatv; -using llvm::isa; -using llvm::Optional; -using llvm::StringRef; -using llvm::Twine; -using mlir::Dialect; -using mlir::ElementsAttr; -using 
mlir::FuncOp; -using mlir::MLIRContext; -using mlir::ModuleOp; -using mlir::NoneType; -using mlir::Operation; -using mlir::Region; -using mlir::StringAttr; -using mlir::TensorType; -using mlir::Type; -using mlir::UnknownLoc; -using mlir::Value; -using tensorflow::OpOrArgLocNameMapper; -using tensorflow::OpOrArgNameMapper; -using tensorflow::Status; -using tflite::flex::IsWhitelistedFlexOp; -using xla::StatusOr; - -template -using BufferOffset = flatbuffers::Offset; - -template -using VectorBufferOffset = flatbuffers::Offset>; - -using CustomOptionsOffset = VectorBufferOffset; - -namespace error = tensorflow::error; -namespace tfl = mlir::TFL; - -ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; - -// Use initial buffer size in flatbuffer builder to be same as the initial size -// used by the TOCO export. (It does not explain rationale for this choice.) -constexpr size_t kInitialBufferSize = 10240; - -// Set `isSigned` to false if the `type` is an 8-bit unsigned integer type. -// Since tflite doesn't support unsigned for other types, returns error if -// `isSigned` is set to false for other types. -static StatusOr GetTFLiteType(Type type, - bool is_signed = true) { - if (!is_signed && type.isSignlessInteger(8)) { - return tflite::TensorType_UINT8; - } - if (!is_signed) { - return Status(error::INVALID_ARGUMENT, - "'isSigned' can only be set for 8-bits integer type"); - } - switch (type.getKind()) { - case mlir::StandardTypes::F32: - return tflite::TensorType_FLOAT32; - case mlir::StandardTypes::F16: - return tflite::TensorType_FLOAT16; - case mlir::TF::TensorFlowTypes::STRING: - return tflite::TensorType_STRING; - case mlir::TF::TensorFlowTypes::QUINT8: - return tflite::TensorType_UINT8; - case mlir::StandardTypes::Complex: { - auto ftype = type.cast().getElementType(); - if (ftype && ftype.isF32()) { - return tflite::TensorType_COMPLEX64; - } - return Status(error::INVALID_ARGUMENT, "Unsupported type"); - } - case mlir::StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - return tflite::TensorType_BOOL; - case 8: - return itype.isUnsigned() ? tflite::TensorType_UINT8 - : tflite::TensorType_INT8; - case 16: - return tflite::TensorType_INT16; - case 32: - return tflite::TensorType_INT32; - case 64: - return tflite::TensorType_INT64; - } - } - case mlir::quant::QuantizationTypes::UniformQuantized: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); - } - case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); - } - case mlir::TF::TensorFlowTypes::RESOURCE: { - // Treat tf.resource values as integer values in flatbuffer. - // TODO(b/146131919): Maybe need to have a detailed design for supporting - // other resource types beyonds hash table resources and resource - // variables. - return tflite::TensorType_INT32; - } - default: - // TFLite export fills FLOAT32 for unknown data types. Returning an error - // for now for safety and this could be revisited when required. - return Status(error::INVALID_ARGUMENT, "Unsupported type"); - } -} - -static bool IsConst(Operation* op) { - return isa(op) || isa(op) || - isa(op) || isa(op); -} - -template -static bool HasValidTFLiteType(Value value, T& error_handler) { - // None type is allowed to represent unspecified operands. 
- if (value.getType().isa()) return true; - - auto type = value.getType().dyn_cast(); - if (!type) { - if (auto op = value.getDefiningOp()) { - error_handler.emitError() - << '\'' << op << "' should produce value of tensor type instead of " - << value.getType(); - return false; - } - error_handler.emitError("expected tensor type, got ") << value.getType(); - return false; - } - - Type element_type = type.getElementType(); - auto status = GetTFLiteType(element_type); - if (!status.ok()) { - return error_handler.emitError( - formatv("Failed to convert element type '{0}': {1}", - element_type, status.status().error_message())), - false; - } - return true; -} - -// Returns true if the module holds all the invariants expected by the -// Translator class. -// TODO(hinsu): Now that translation is done by making a single pass over the -// MLIR module, consider inlining these validation checks at the place where -// these invariants are assumed instead of checking upfront. -static bool IsValidTFLiteMlirModule(ModuleOp module) { - MLIRContext* context = module.getContext(); - - // Verify that module has a function named main. - FuncOp main_fn = module.lookupSymbol("main"); - if (!main_fn) { - return emitError(UnknownLoc::get(context), - "should have a function named 'main'"), - false; - } - - for (auto fn : module.getOps()) { - if (fn.getBlocks().size() != 1) { - return fn.emitError("should have exactly one basic block"), false; - } - auto& bb = fn.getBlocks().front(); - - for (auto arg : bb.getArguments()) { - if (!HasValidTFLiteType(arg, fn)) - return fn.emitError("invalid TFLite type: ") << arg.getType(), false; - } - - // Verify that all operations except the terminator have exactly one - // result of type supported by TFLite. - for (auto& inst : bb) { - if (inst.isKnownTerminator()) break; - - for (auto result : inst.getResults()) { - if (!HasValidTFLiteType(result, inst)) - return fn.emitError("invalid TFLite type: ") << result.getType(), - false; - } - } - } - - return true; -} - -static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef( - ::mlir::Operation* inst) { - // We pass empty string for the original node_def name since Flex runtime - // does not care about this being set correctly on node_def. There is no - // "easy" (see b/120948529) way yet to get this from MLIR inst. - auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef( - inst, /*name=*/"", /*ignore_unregistered_attrs=*/true); - if (!status_or_node_def.ok()) { - inst->emitOpError( - Twine("failed to obtain TensorFlow nodedef with status: " + - status_or_node_def.status().ToString())); - return {}; - } - return std::move(status_or_node_def.ValueOrDie()); -} - -// Converts a mlir padding StringRef to TfLitePadding. -// Returns llvm::None if conversion fails. -static Optional GetTflitePadding(Operation* inst, - llvm::StringRef padding) { - const tflite::Padding padding_attr = - std::move(llvm::StringSwitch(padding) - .Case("SAME", tflite::Padding_SAME) - .Case("VALID", tflite::Padding_VALID)); - if (padding_attr == tflite::Padding_SAME) { - return kTfLitePaddingSame; - } - if (padding_attr == tflite::Padding_VALID) { - return kTfLitePaddingValid; - } - - return inst->emitOpError() << "Invalid padding attribute: " << padding, - llvm::None; -} - -// Extracts TfLitePoolParams from a TFL custom op. -// Template parameter, TFLOp, should be a TFL custom op containing attributes -// generated from TfLitePoolParams. -// Returns llvm::None if conversion fails. 
-template -static Optional GetTflitePoolParams(Operation* inst, - TFLOp op) { - TfLitePoolParams pool_params; - pool_params.stride_height = op.stride_h().getSExtValue(); - pool_params.stride_width = op.stride_w().getSExtValue(); - pool_params.filter_height = op.filter_h().getSExtValue(); - pool_params.filter_width = op.filter_w().getSExtValue(); - const auto padding = GetTflitePadding(inst, op.padding()); - if (padding) { - pool_params.padding = *padding; - pool_params.activation = kTfLiteActNone; - pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0}; - return pool_params; - } - - return llvm::None; -} - -namespace { - -// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. -class Translator { - public: - // Translates the given MLIR module into TFLite FlatBuffer format and returns - // the serialized output. Returns llvm::None on unsupported, invalid inputs or - // internal error. - static Optional Translate( - ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper); - - private: - enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; - explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, - bool emit_select_tf_ops, bool emit_custom_ops, - OpOrArgNameMapper* op_or_arg_name_mapper) - : module_(module), - name_mapper_(*op_or_arg_name_mapper), - builder_(kInitialBufferSize) { - // The first buffer must be empty according to the schema definition. - empty_buffer_ = tflite::CreateBuffer(builder_); - buffers_.push_back(empty_buffer_); - if (emit_builtin_tflite_ops) { - enabled_op_types_.emplace(OpType::kTfliteBuiltin); - } - if (emit_select_tf_ops) { - enabled_op_types_.emplace(OpType::kSelectTf); - } - if (emit_custom_ops) { - enabled_op_types_.emplace(OpType::kCustomOp); - } - tf_dialect_ = module.getContext()->getRegisteredDialect("tf"); - tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl"); - } - - Optional TranslateInternal(); - - // Returns TFLite buffer populated with constant value if the operation is - // TFLite constant operation. Otherwise, returns an empty buffer. Emits error - // and returns llvm::None on failure. - Optional> BuildBuffer(Operation* inst); - - // Build TFLite tensor from the given type. This function is for tfl.lstm - // intermediates, which should have UniformQuantizedType. - Optional> BuildTensorFromType( - mlir::Type type, const std::string& name); - - // Builds TFLite tensor from the given value. `buffer_idx` is index of the - // corresponding buffer. Emits error and returns llvm::None on failure. - Optional> BuildTensor(Value value, - const std::string& name, - unsigned buffer_idx); - - // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove - // these 2 functions here. - BufferOffset BuildIfOperator( - mlir::TF::IfOp op, const std::vector& operands, - const std::vector& results); - BufferOffset BuildWhileOperator( - mlir::TF::WhileOp op, const std::vector& operands, - const std::vector& results); - - // Build while operator where cond & body are regions. - Optional> BuildWhileOperator( - mlir::TFL::WhileOp op, const std::vector& operands, - const std::vector& results); - - // Builds custom operators. - // Templated on a) data type of custom_option to be stored into flatbuffer, - // and b) TFL custom op type. 
- template - BufferOffset BuildCustomOperator( - const CustomOptionType& custom_option, const std::string& opcode_name, - TFLOp op, const std::vector& operands, - const std::vector& results); - - BufferOffset BuildNumericVerifyOperator( - mlir::TFL::NumericVerifyOp op, const std::vector& operands, - const std::vector& results); - Optional> - BuildConvolution2DTransposeBiasOperator( - Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, - const std::vector& operands, - const std::vector& results); - Optional> BuildMaxPoolingWithArgMax2DOperator( - Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, - const std::vector& operands, - const std::vector& results); - Optional> BuildMaxUnpooling2DOperator( - Operation* inst, mlir::TFL::MaxUnpooling2DOp op, - const std::vector& operands, - const std::vector& results); - - Optional CreateFlexOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - Optional CreateCustomOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - std::unique_ptr CreateFlexBuilderWithNodeAttrs( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - - // Returns opcode index for op identified by the op_name, if already - // available. Otherwise, creates a new OperatorCode using the given `builtin` - // operator and associates it with `op_name`. - uint32_t GetOpcodeIndex(const std::string& op_name, - tflite::BuiltinOperator builtin); - - // Builds operator for the given operation with specified operand and result - // tensor indices. Emits an error and returns llvm::None on failure. - Optional> BuildOperator( - Operation* inst, const std::vector& operands, - const std::vector& results, - const std::vector& intermediates); - - // Build a subgraph with a given name out of the region either corresponding - // to a function's body or while op. - Optional> BuildSubGraph( - const std::string& name, Region* region); - - // Builds Metadata with the given `name` and buffer `content`. - BufferOffset BuildMetadata(StringRef name, - StringRef content); - - // Encodes the `tfl.metadata` dictionary attribute of the module to the - // metadata section in the final model. - Optional>> - CreateMetadataVector(); - - // Uses the tf.entry_function attribute (if set) to initialize the op to name - // mapping. - void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr); - - // Determines if the specified operation op's operand at operand_index - // is marked as a stateful operand. - bool IsStatefulOperand(mlir::Operation* op, int operand_index); - - // Returns a unique name for `val`. - std::string UniqueName(mlir::Value val); - - ModuleOp module_; - - tensorflow::OpOrArgNameMapper& name_mapper_; - - flatbuffers::FlatBufferBuilder builder_; - BufferOffset empty_buffer_; - - std::vector> buffers_; - - // Maps op name to index of the corresponding OperatorCode in opcodes_ vector. - absl::flat_hash_map opcode_index_map_; - std::vector> opcodes_; - - // Maps function name to index of the corresponding subgraph in the FlatBuffer - // model. - absl::flat_hash_map subgraph_index_map_; - absl::flat_hash_set enabled_op_types_; - - // Points to TensorFlow and TFLite dialects, respectively. nullptr if the - // dialect is not registered. - const Dialect* tf_dialect_; - const Dialect* tfl_dialect_; - - // The failed ops during legalization. 
- std::set failed_flex_ops_; - std::set failed_custom_ops_; -}; - -std::string Translator::UniqueName(mlir::Value val) { - return std::string(name_mapper_.GetUniqueName(val)); -} - -Optional> Translator::BuildBuffer( - Operation* inst) { - ElementsAttr attr; - if (auto cst = dyn_cast(inst)) { - // ConstantOp have ElementAttr at this point due to validation of the TFLite - // module. - attr = cst.getValue().cast(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); - } else { - return empty_buffer_; - } - - tensorflow::Tensor tensor; - auto status = tensorflow::ConvertToTensor(attr, &tensor); - if (!status.ok()) { - inst->emitError( - Twine("failed to convert value attribute to tensor with error: " + - status.ToString())); - return llvm::None; - } - - // TensorFlow and TensorFlow Lite use different string encoding formats. - // Convert to TensorFlow Lite format is it's a constant string tensor. - if (tensor.dtype() == tensorflow::DT_STRING) { - ::tflite::DynamicBuffer dynamic_buffer; - auto flat = tensor.flat<::tensorflow::tstring>(); - for (int i = 0; i < flat.size(); ++i) { - const auto& str = flat(i); - dynamic_buffer.AddString(str.c_str(), str.length()); - } - char* tensor_buffer; - int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer); - auto buffer_data = - builder_.CreateVector(reinterpret_cast(tensor_buffer), bytes); - free(tensor_buffer); - return tflite::CreateBuffer(builder_, buffer_data); - } - - absl::string_view tensor_data = tensor.tensor_data(); - auto buffer_data = builder_.CreateVector( - reinterpret_cast(tensor_data.data()), tensor_data.size()); - return tflite::CreateBuffer(builder_, buffer_data); -} - -Optional> Translator::BuildTensorFromType( - mlir::Type type, const std::string& name) { - auto tensor_type = type.cast(); - - if (!tensor_type.hasStaticShape()) { - return llvm::None; - } - llvm::ArrayRef shape_ref = tensor_type.getShape(); - std::vector shape(shape_ref.begin(), shape_ref.end()); - - auto element_type = tensor_type.getElementType(); - tflite::TensorType tflite_element_type = - GetTFLiteType(tensor_type.getElementType()).ValueOrDie(); - BufferOffset q_params; - auto qtype = element_type.dyn_cast(); - if (!qtype) { - return llvm::None; - } - q_params = tflite::CreateQuantizationParameters( - builder_, /*min=*/0, /*max=*/0, - builder_.CreateVector({static_cast(qtype.getScale())}), - builder_.CreateVector({qtype.getZeroPoint()})); - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - /*buffer=*/0, builder_.CreateString(name), q_params, - /*is_variable=*/false); -} - -Optional> Translator::BuildTensor( - Value value, const std::string& name, unsigned buffer_idx) { - auto type = value.getType().cast(); - - // TFLite requires tensor shape only for the inputs and constants. 
- // However, we output all known shapes for better round-tripping - auto check_shape = - [&](llvm::ArrayRef shape_ref) -> mlir::LogicalResult { - auto is_out_of_range = [](int64_t dim) { - return dim > std::numeric_limits::max(); - }; - - if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) - return mlir::emitError( - value.getLoc(), - "result shape dimensions out of 32 bit int type range"); - - return mlir::success(); - }; - - std::vector shape; - std::vector shape_signature; - if (type.hasStaticShape()) { - llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } else if (auto* inst = value.getDefiningOp()) { - if (IsConst(inst)) { - // Const op can have a result of dynamic shaped type (e.g. due to constant - // folding), but we can still derive the shape of a constant tensor for - // its attribute type. - mlir::Attribute tensor_attr = inst->getAttr("value"); - llvm::ArrayRef shape_ref = - tensor_attr.getType().cast().getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } - } else if (type.hasRank()) { - llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; - - shape.reserve(shape_ref.size()); - for (auto& dim : shape_ref) { - shape.push_back(dim == -1 ? 1 : dim); - } - shape_signature = std::vector(shape_ref.begin(), shape_ref.end()); - } - - if (auto* inst = value.getDefiningOp()) { - if (auto cst = dyn_cast(inst)) { - // CreateSparsityParameters(cst.s_param()); - } else if (auto cst = dyn_cast(inst)) { - // CreateSparsityParameters(cst.s_param()); - } - } - - Type element_type = type.getElementType(); - tflite::TensorType tflite_element_type = - GetTFLiteType(type.getElementType()).ValueOrDie(); - - BufferOffset q_params; - if (auto qtype = element_type.dyn_cast()) { - q_params = tflite::CreateQuantizationParameters( - // TODO(fengliuai): min and max values are not stored in the - // quantized type, so both are set to 0. The model couldn't be imported - // to TensorFlow because of this. - builder_, /*min=*/0, /*max=*/0, - builder_.CreateVector({static_cast(qtype.getScale())}), - builder_.CreateVector({qtype.getZeroPoint()})); - } else if (auto qtype = - element_type - .dyn_cast()) { - std::vector scales(qtype.getScales().begin(), - qtype.getScales().end()); - q_params = tflite::CreateQuantizationParameters( - builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), - builder_.CreateVector(qtype.getZeroPoints()), - tflite::QuantizationDetails_NONE, /*details=*/0, - qtype.getQuantizedDimension()); - } else { - q_params = tflite::CreateQuantizationParameters(builder_); - } - // Check if the value's uses includes an op and usage at an operand index - // marked as a stateful. If so, set the tensor's is_variable as true - // This is v1 ref variable semantics in the TFLite runtime. - bool is_variable = false; - for (auto& use : value.getUses()) { - is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); - if (is_variable) { - break; - } - } - - if (shape_signature.empty()) { - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable); - } else { - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - (is_variable ? 
0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable, /*sparsity=*/0, - /*shape_signature=*/builder_.CreateVector(shape_signature)); - } -} - -BufferOffset Translator::BuildIfOperator( - mlir::TF::IfOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); - int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); - int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); - auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, - else_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_IfOptions, - builtin_options); -} - -BufferOffset Translator::BuildWhileOperator( - mlir::TF::WhileOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); - int cond_subgraph_index = subgraph_index_map_.at(op.cond().str()); - int body_subgraph_index = subgraph_index_map_.at(op.body().str()); - auto builtin_options = tflite::CreateWhileOptions( - builder_, cond_subgraph_index, body_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_WhileOptions, - builtin_options); -} - -Optional> Translator::BuildWhileOperator( - mlir::TFL::WhileOp op, const std::vector& operands, - const std::vector& results) { - auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); - auto get_call_index = [&](mlir::Block& b) -> Optional { - if (b.getOperations().size() != 2) return llvm::None; - if (auto call_op = dyn_cast(b.front())) - return subgraph_index_map_.at(call_op.callee().str()); - return llvm::None; - }; - auto body_subgraph_index = get_call_index(op.body().front()); - auto cond_subgraph_index = get_call_index(op.cond().front()); - if (!body_subgraph_index || !cond_subgraph_index) - return op.emitOpError("only single call cond/body while export supported"), - llvm::None; - auto builtin_options = - tflite::CreateWhileOptions(builder_, *cond_subgraph_index, - *body_subgraph_index) - .Union(); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_WhileOptions, - builtin_options); -} - -template -BufferOffset Translator::BuildCustomOperator( - const CustomOptionType& custom_option, const std::string& opcode_name, - TFLOp op, const std::vector& operands, - const std::vector& results) { - std::vector custom_option_vector(sizeof(CustomOptionType)); - memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType)); - auto opcode_index = - GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM); - return tflite::CreateOperator( - builder_, opcode_index, builder_.CreateVector(operands), - builder_.CreateVector(results), tflite::BuiltinOptions_NONE, - /*builtin_options=*/0, - builder_.CreateVector(custom_option_vector), - tflite::CustomOptionsFormat_FLEXBUFFERS); -} - -BufferOffset Translator::BuildNumericVerifyOperator( - mlir::TFL::NumericVerifyOp op, const std::vector& operands, - const std::vector& results) { - float tolerance = op.tolerance().convertToFloat(); 
- return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results); -} - -Optional> -Translator::BuildConvolution2DTransposeBiasOperator( - Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, - const std::vector& operands, const std::vector& results) { - TfLiteTransposeConvParams conv_params; - conv_params.stride_height = op.stride_h().getSExtValue(); - conv_params.stride_width = op.stride_w().getSExtValue(); - const auto padding = GetTflitePadding(inst, op.padding()); - if (padding) { - conv_params.padding = *padding; - return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op, - operands, results); - } - - return llvm::None; -} - -Optional> -Translator::BuildMaxPoolingWithArgMax2DOperator( - Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, - const std::vector& operands, const std::vector& results) { - const auto pool_params = GetTflitePoolParams(inst, op); - if (pool_params) { - return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op, - operands, results); - } - - return llvm::None; -} - -Optional> -Translator::BuildMaxUnpooling2DOperator(Operation* inst, - mlir::TFL::MaxUnpooling2DOp op, - const std::vector& operands, - const std::vector& results) { - const auto pool_params = GetTflitePoolParams(inst, op); - if (pool_params) { - return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands, - results); - } - - return llvm::None; -} - -Optional Translator::CreateFlexOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } - - auto flex_builder = absl::make_unique(); - flex_builder->Vector([&]() { - flex_builder->String(node_def.op()); - flex_builder->String(node_def_str); - }); - flex_builder->Finish(); - return builder_.CreateVector(flex_builder->GetBuffer()); -} - -Optional Translator::CreateCustomOpCustomOptions( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } - auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); - return builder_.CreateVector(flex_builder->GetBuffer()); -} - -std::unique_ptr -Translator::CreateFlexBuilderWithNodeAttrs( - const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - auto flex_builder = absl::make_unique(); - size_t map_start = flex_builder->StartMap(); - for (const auto& pair : node_def.attr()) { - const char* key = pair.first.c_str(); - const auto& attr = pair.second; - switch (attr.value_case()) { - case ::tensorflow::AttrValue::kS: - flex_builder->String(key, attr.s()); - break; - case ::tensorflow::AttrValue::kType: { - auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type()); - if (status_or_tfl_type.ok()) { - flex_builder->Int(key, status_or_tfl_type.ValueOrDie()); - } else { - emitWarning(loc, "ignoring unsupported tensorflow type: ") - << std::to_string(attr.type()); - } - break; - } - case ::tensorflow::AttrValue::kI: - flex_builder->Int(key, attr.i()); - break; - case ::tensorflow::AttrValue::kF: - flex_builder->Float(key, attr.f()); - break; - case ::tensorflow::AttrValue::kB: - flex_builder->Bool(key, attr.b()); - break; - case tensorflow::AttrValue::kList: - if (attr.list().s_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const std::string& v : 
attr.list().s()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else if (attr.list().i_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const int64_t v : attr.list().i()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else if (attr.list().f_size() > 0) { - auto start = flex_builder->StartVector(key); - for (const float v : attr.list().f()) { - flex_builder->Add(v); - } - flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); - } else { - emitWarning(loc, - "ignoring unsupported type in list attribute with key: ") - << key; - } - break; - default: - emitWarning(loc, "ignoring unsupported attribute type with key: ") - << key; - break; - } - } - flex_builder->EndMap(map_start); - flex_builder->Finish(); - return flex_builder; -} - -uint32_t Translator::GetOpcodeIndex(const std::string& op_name, - tflite::BuiltinOperator builtin) { - auto it = opcode_index_map_.insert({op_name, 0}); - - // If the insert succeeded, the opcode has not been created already. Create a - // new operator code and update its index value in the map. - if (it.second) { - it.first->second = opcodes_.size(); - auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM - ? builder_.CreateString(op_name) - : BufferOffset(); - // Use version 0 for builtin op. This is a way to serialize version field to - // flatbuffer (since 0 is non default) and it will be corrected later. - int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1; - opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin, - custom_code, op_version)); - } - return it.first->second; -} - -Optional> Translator::BuildOperator( - Operation* inst, const std::vector& operands, - const std::vector& results, - const std::vector& intermediates) { - const auto* dialect = inst->getDialect(); - if (!dialect) { - inst->emitOpError("dialect is not registered"); - return llvm::None; - } - - // If TFLite built in op, create operator as a builtin op. - if (dialect == tfl_dialect_) { - // Only if built-in TFLite op emission is enabled, would legalization have - // converted any TF->TFL. 
- if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) { - return inst->emitOpError( - "is a TFLite builtin op but builtin emission is not enabled"), - llvm::None; - } - - auto builtin_code = GetBuiltinOpCode(inst); - if (!builtin_code) { - if (auto verify_op = dyn_cast(inst)) { - return BuildNumericVerifyOperator(verify_op, operands, results); - } - if (auto conv_transpose_bias_op = - dyn_cast(inst)) { - return BuildConvolution2DTransposeBiasOperator( - inst, conv_transpose_bias_op, operands, results); - } - if (auto max_pooling_with_arg_max_op = - dyn_cast(inst)) { - return BuildMaxPoolingWithArgMax2DOperator( - inst, max_pooling_with_arg_max_op, operands, results); - } - if (auto max_unpooling_op = dyn_cast(inst)) { - return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands, - results); - } - if (auto whileOp = dyn_cast(inst)) { - if (inst->getNumOperands() != inst->getNumResults()) { - inst->emitOpError( - "number of operands and results don't match, only canonical " - "TFL While supported"); - return llvm::None; - } - return BuildWhileOperator(whileOp, operands, results); - } - - inst->emitOpError("is not a supported TFLite op"); - return llvm::None; - } - - std::string op_name = inst->getName().getStringRef().str(); - uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code); - auto offset = CreateFlatBufferOperator(inst, opcode_index, operands, - results, intermediates, &builder_); - if (!offset) { - inst->emitOpError("is not a supported TFLite op"); - } - return offset; - } - - if (dialect == tf_dialect_) { - std::string op_name; - if (auto ifOp = dyn_cast(inst)) { - return BuildIfOperator(ifOp, operands, results); - } else if (auto whileOp = dyn_cast(inst)) { - return BuildWhileOperator(whileOp, operands, results); - } - - CustomOptionsOffset custom_options; - - // Ops in TF dialect can either be custom ops or flex ops. - // The reason we go directly from TensorFlow dialect MLIR to tensorflow - // node instead of going to TF table gen'd ops via generated code is that - // we do not want to restrict custom and flex op conversion support to - // only those TF ops that are currently registered in MLIR. The current - // model is of an open op system. - // - // The following algorithm is followed: - // if flex is enabled and the op is whitelisted as flex - // we emit op as flex. - // if custom is enabled - // we emit the op as custom. - auto node_def = GetTensorFlowNodeDef(inst); - if (!node_def) { - return llvm::None; - } - - // Flex op case - // Eventually, the whitelist will go away and we will rely on some TF op - // trait (e.g. No side effect) to determine if it is a supported "Flex" - // op or not. - if (enabled_op_types_.contains(OpType::kSelectTf) && - IsWhitelistedFlexOp(node_def->op())) { - // Construct ops as flex op encoding TensorFlow node definition - // as custom options. - // Flex ops are named with the kFlexOpNamePrefix prefix to the actual - // TF op name. - op_name = std::string(kFlexOpNamePrefix) + node_def->op(); - if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } - } else if (enabled_op_types_.contains(OpType::kCustomOp)) { - // Generic case of custom ops - write using flex buffers since that - // is the only custom options supported by TFLite today. 
- op_name = node_def->op(); - if (auto options = - CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } - } else { - // Create description of operation that could not be converted. - const int kLargeElementsAttr = 16; - std::string op_str; - llvm::raw_string_ostream os(op_str); - inst->getName().print(os); - // Print out attributes except for large elementsattributes (which should - // rarely be the cause why the legalization didn't happen). - if (!inst->getAttrList().getAttrs().empty()) { - os << " {"; - bool first = true; - for (auto& named_attr : inst->getAttrList().getDictionary()) { - os << (!first ? ", " : ""); - first = false; - named_attr.first.print(os); - os << " = "; - if (auto element_attr = named_attr.second.dyn_cast()) { - if (element_attr.getNumElements() <= kLargeElementsAttr) { - element_attr.print(os); - } else { - os << ""; - } - } else { - named_attr.second.print(os); - } - } - os << "}"; - } - - // Insert failed op to `flex_ops` or `custom_ops`. - if (IsWhitelistedFlexOp(node_def->op())) { - failed_flex_ops_.insert(os.str()); - } else { - failed_custom_ops_.insert(os.str()); - } - return inst->emitOpError("is neither a custom op nor a flex op"), - llvm::None; - } - - uint32_t opcode_index = - GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM); - auto inputs = builder_.CreateVector(operands); - auto outputs = builder_.CreateVector(results); - - return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, - tflite::BuiltinOptions_NONE, - /*builtin_options=*/0, - /*custom_options=*/custom_options, - tflite::CustomOptionsFormat_FLEXBUFFERS, - /*mutating_variable_inputs=*/0); - } - - return inst->emitOpError( - "is not any of a builtin TFLite op, a flex TensorFlow op or a " - "custom TensorFlow op"), - llvm::None; -} - -void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { - auto dict_attr = fn.getAttrOfType("tf.entry_function"); - if (!dict_attr) return; - - llvm::SmallVector input_names; - llvm::SmallVector output_names; - if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { - str.getValue().split(input_names, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - if (input_names.size() != fn.getNumArguments()) { - fn.emitWarning() << "invalid entry function specification"; - return; - } - for (auto it : llvm::enumerate(fn.getArguments())) { - name_mapper_.InitOpName(it.value(), input_names[it.index()].trim()); - } - *has_input_attr = true; - } - - if (auto str = - dict_attr.get("outputs").dyn_cast_or_null()) { - str.getValue().split(output_names, ',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - auto term = fn.getBlocks().back().getTerminator(); - if (output_names.size() != term->getNumOperands()) { - fn.emitWarning() << "output names (" << output_names.size() - << ") != terminator operands (" << term->getNumOperands() - << ")"; - return; - } - for (const auto& it : llvm::enumerate(term->getOperands())) { - name_mapper_.InitOpName(it.value(), output_names[it.index()].trim()); - } - } -} - -bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { - std::vector operand_indices; - if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; - return absl::c_find(operand_indices, operand_index) != operand_indices.end(); -} - -Optional> Translator::BuildSubGraph( - const std::string& name, Region* region) { - bool has_input_attr = false; - if (auto fn = dyn_cast(region->getParentOp())) { - InitializeNamesFromAttribute(fn, &has_input_attr); - 
} - std::vector> tensors; - llvm::DenseMap tensor_index_map; - - // Builds tensor and buffer for argument or operation result. Returns false - // on failure. - auto build_tensor_and_buffer = [&](Value value, const std::string& name) { - // NoneType represents optional and may be skipped here. - if (value.getType().isa()) { - return true; - } - - tensor_index_map.insert({value, tensors.size()}); - auto tensor_or = BuildTensor(value, name, buffers_.size()); - if (!tensor_or) return false; - tensors.push_back(*tensor_or); - - // TODO(ashwinm): Check if for stateful tensors, if it is also needed to - // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. - // This does not seem to affect runtime behavior for RNN/LSTM, but would be - // good for reducing memory footprint. - if (auto* inst = value.getDefiningOp()) { - auto buffer_or = BuildBuffer(inst); - if (!buffer_or) return false; - buffers_.push_back(*buffer_or); - } else { - buffers_.push_back(empty_buffer_); - } - return true; - }; - - std::vector> operators; - auto& bb = region->front(); - - // Main function's arguments are first passed to `input` op so they don't - // have associated tensor and buffer. Build FlatBuffer tensor and buffer for - // other functions. - for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { - mlir::BlockArgument arg = bb.getArgument(i); - std::string name; - if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg)); - if (name.empty()) name = absl::StrCat("arg", i); - if (!build_tensor_and_buffer(arg, name)) return llvm::None; - } - - bool failed_once = false; - for (auto& inst : bb) { - if (inst.isKnownTerminator()) break; - std::vector intermediates; - // Build intermediate tensors for tfl.lstm and insert these tensors into - // flatbuffer. - if (llvm::isa(inst)) { - std::vector intermediate_names = { - "input_to_input_intermediate", "input_to_forget_intermediate", - "input_to_cell_intermediate", "input_to_output_intermediate", - "effective_hidden_scale_intermediate"}; - for (const std::string& intermediate : intermediate_names) { - auto intermediate_attr = inst.getAttr(intermediate); - if (auto attr = intermediate_attr.dyn_cast_or_null()) { - Type qtype = attr.getValue(); - auto tensor_or = BuildTensorFromType( - qtype, name_mapper_.GetUniqueName(intermediate).str()); - if (!tensor_or.hasValue()) { - continue; - } else { - intermediates.push_back(tensors.size()); - tensors.push_back(tensor_or.getValue()); - } - } - } - } - - for (auto val : inst.getResults()) { - std::string name = UniqueName(val); - if (!build_tensor_and_buffer(val, name)) return llvm::None; - } - - // Skip constant ops as they don't represent a TFLite operator. - if (IsConst(&inst)) continue; - - // Fetch operand and result tensor indices. - std::vector operands; - operands.reserve(inst.getNumOperands()); - for (auto operand : inst.getOperands()) { - if (operand.getType().isa()) - operands.push_back(kTfLiteOptionalTensor); - else - operands.push_back(tensor_index_map.lookup(operand)); - } - std::vector results; - results.reserve(inst.getNumOperands()); - for (auto result : inst.getResults()) { - results.push_back(tensor_index_map.lookup(result)); - } - - if (auto tfl_operator = - BuildOperator(&inst, operands, results, intermediates)) - operators.push_back(*tfl_operator); - else - failed_once = true; - } - - if (failed_once) return llvm::None; - - // Get input and output tensor indices for the subgraph. 
- std::vector inputs, outputs; - for (auto arg : bb.getArguments()) { - inputs.push_back(tensor_index_map[arg]); - } - for (auto result : bb.getTerminator()->getOperands()) { - outputs.push_back(tensor_index_map[result]); - } - - return tflite::CreateSubGraph( - builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs), - builder_.CreateVector(outputs), builder_.CreateVector(operators), - /*name=*/builder_.CreateString(name)); -} - -BufferOffset Translator::BuildMetadata(StringRef name, - StringRef content) { - auto buffer_index = buffers_.size(); - auto buffer_data = builder_.CreateVector( - reinterpret_cast(content.data()), content.size()); - buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data)); - return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index); -} - -Optional>> -Translator::CreateMetadataVector() { - auto dict_attr = module_.getAttrOfType("tfl.metadata"); - std::vector> metadata; - if (dict_attr) { - for (const auto& named_attr : dict_attr) { - StringRef name = named_attr.first; - mlir::Attribute attr = named_attr.second; - if (auto content = attr.dyn_cast()) { - metadata.push_back(BuildMetadata(name, content.getValue())); - } else { - module_.emitError( - "all values in tfl.metadata's dictionary key-value pairs should be " - "string attributes"); - return llvm::None; - } - } - } - // Runtime version string is generated after we update the op - // versions. Here we put a 16-byte dummy string as a placeholder. We choose - // 16-byte because it's the alignment of buffers in flatbuffer, so it won't - // cause any waste of space if the actual string is shorter than 16 bytes. - metadata.push_back( - BuildMetadata("min_runtime_version", std::string(16, '\0'))); - return builder_.CreateVector(metadata); -} - -bool UpdateEntryFunction(ModuleOp module) { - if (module.lookupSymbol("main") != nullptr) { - // We already have an entry function. - return true; - } - - int entry_func_count = 0; - FuncOp entry_func = nullptr; - for (auto fn : module.getOps()) { - auto attrs = fn.getAttrOfType("tf.entry_function"); - if (attrs && !attrs.empty()) { - entry_func_count++; - entry_func = fn; - } - } - - // We should have one & only have one entry function. - if (entry_func_count != 1) return false; - - // Update the entry func to main. - entry_func.setName("main"); - return true; -} - -Optional Translator::Translate( - ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { - if (!UpdateEntryFunction(module)) return llvm::None; - if (!IsValidTFLiteMlirModule(module)) return llvm::None; - Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, op_or_arg_name_mapper); - return translator.TranslateInternal(); -} - -Optional Translator::TranslateInternal() { - // A list of named regions in the module with main function being the first in - // the list. The main function is required as the first subgraph in the model - // is entry point for the model. - std::vector> named_regions; - named_regions.reserve(std::distance(module_.begin(), module_.end())); - - int subgraph_idx = 0; - FuncOp main_fn = module_.lookupSymbol("main"); - subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++; - named_regions.emplace_back("main", &main_fn.getBody()); - // Walk over the module collection ops with functions and while ops. 
- module_.walk([&](FuncOp fn) { - if (fn != main_fn) { - subgraph_index_map_[fn.getName().str()] = subgraph_idx++; - named_regions.emplace_back(fn.getName().str(), &fn.getBody()); - } - }); - - // Build subgraph for each of the named regions. - std::vector> subgraphs; - subgraphs.reserve(named_regions.size()); - int first_failed_func = -1; - for (auto it : llvm::enumerate(named_regions)) { - auto subgraph_or = BuildSubGraph(it.value().first, it.value().second); - if (!subgraph_or) { - if (first_failed_func == -1) - // Record the index of the first region that cannot be converted. - // Keep looping through all subgraphs in the module to make sure that - // we collect the list of missing ops from the entire module. - first_failed_func = it.index(); - } else { - subgraphs.push_back(*subgraph_or); - } - } - - if (first_failed_func != -1) { - std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, "\n\t"); - std::string failed_custom_ops_list = - absl::StrJoin(failed_custom_ops_, "\n\t"); - std::string err; - if (!failed_flex_ops_list.empty()) - err += - "Ops that can be supported by the flex runtime (enabled via setting " - "the -emit-select-tf-ops flag):\n\t" + - failed_flex_ops_list; - if (!failed_custom_ops_list.empty()) - err += - "Ops that need custom implementation (enabled via setting the " - "-emit-custom-ops flag):\n\t" + - failed_custom_ops_list; - - auto& failed_region = named_regions[first_failed_func]; - return failed_region.second->getParentOp()->emitError() - << "failed while converting: '" << failed_region.first - << "': " << err, - llvm::None; - } - - std::string model_description; - if (auto attr = module_.getAttrOfType("tfl.description")) { - model_description = attr.getValue().str(); - } else { - model_description = "MLIR Converted."; - } - - // Build the model and finish the model building process. - auto description = builder_.CreateString(model_description.data()); - VectorBufferOffset metadata_buffer = 0; // Deprecated - auto metadata = CreateMetadataVector(); - if (!metadata) return llvm::None; - - auto model = tflite::CreateModel( - builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), - builder_.CreateVector(subgraphs), description, - builder_.CreateVector(buffers_), metadata_buffer, *metadata); - tflite::FinishModelBuffer(builder_, model); - tflite::UpdateOpVersion(builder_.GetBufferPointer()); - tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); - - // Return serialized string for the built FlatBuffer. - return std::string(reinterpret_cast(builder_.GetBufferPointer()), - builder_.GetSize()); -} - -} // namespace - -// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer -// format. Returns false on success. 
-//
-// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
-// the following:
-//
-// * Quantization
-// * Ops with variable tensors
-//
-bool tflite::MlirToFlatBufferTranslateFunction(
-    ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
-    OpOrArgNameMapper* op_or_arg_name_mapper) {
-  auto maybe_translated =
-      Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops,
-                            emit_custom_ops, op_or_arg_name_mapper);
-  if (!maybe_translated) return true;
-  *serialized_flatbuffer = std::move(*maybe_translated);
-  return false;
-}
-
-bool tflite::MlirToFlatBufferTranslateFunction(
-    ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
-    bool emit_custom_ops) {
-  OpOrArgLocNameMapper op_or_arg_name_mapper;
-  return MlirToFlatBufferTranslateFunction(
-      module, serialized_flatbuffer, emit_builtin_tflite_ops,
-      emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper);
-}
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index 1eec402d35a..4b888764053 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -63,16 +63,20 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"  // TF:llvm-project
 #include "mlir/Translation.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
+#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
+#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/schema/schema_generated.h"
@@ -96,6 +100,45 @@ using xla::StatusOr;
 namespace errors = tensorflow::errors;
 namespace tfl = mlir::TFL;
 
+using llvm::cl::opt;
+
+// Commandline flag to enable the control of flatbuffer import.
+bool use_external_constant;
+
+// Commandline flag to enable graph pruning.
+bool experimental_prune_unreachable_nodes_unconditionally;
+
+// NOLINTNEXTLINE
+static opt<bool> use_external_constant_flag(
+    "use-external-constant",
+    llvm::cl::desc("Use external constant during flatbuffer import"),
+    llvm::cl::location(use_external_constant), llvm::cl::init(false));
+
+// TODO(b/147111261): After the importer supports generic custom ops, we should
+// change the flag to a more lightweight flag, e.g.
+// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE prune
+// the operations.
+// NOLINTNEXTLINE
+static opt<bool> experimental_prune_unreachable_nodes_unconditionally_flg(
+    "experimental-prune-unreachable-nodes-unconditionally",
+    llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
+    llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
+    llvm::cl::init(false));
+
+// NOLINTNEXTLINE
+static opt<std::string> input_arrays_flag(
+    "input-arrays",
+    llvm::cl::desc(
+        "List of input tensors, if different from the default inputs"),
+    llvm::cl::init(""));
+
+// NOLINTNEXTLINE
+static opt<std::string> output_arrays_flag(
+    "output-arrays",
+    llvm::cl::desc(
+        "List of output tensors, if different from the default outputs"),
+    llvm::cl::init(""));
+
 namespace {
 bool IsScalar(const TensorT& tensor) {
   // TODO(b/138222071) We can't distinguish scalars and unranked tensors
@@ -1020,3 +1063,42 @@ OwningModuleRef tflite::FlatBufferToMlir(
 
   return OwningModuleRef(module);
 }
+
+static OwningModuleRef FlatBufferFileToMlirTrans(
+    llvm::SourceMgr* source_mgr, MLIRContext* context,
+    bool use_external_constant,
+    bool experimental_prune_unreachable_nodes_unconditionally) {
+  const llvm::MemoryBuffer* input =
+      source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
+  std::string error;
+  auto loc =
+      mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
+
+  // Parses input/output names from command line options.
+  std::vector<std::string> inputs;
+  std::vector<std::string> outputs;
+  // Use output parser since we only have tensor names.
+  if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
+    return emitError(loc, "parsing input array info failed ")
+               << input_arrays_flag,
+           nullptr;
+  }
+  if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
+    return emitError(loc, "parsing output array info failed ")
+               << output_arrays_flag,
+           nullptr;
+  }
+
+  return tflite::FlatBufferToMlir(
+      absl::string_view(input->getBufferStart(), input->getBufferSize()),
+      context, loc, use_external_constant, inputs, outputs,
+      experimental_prune_unreachable_nodes_unconditionally);
+}
+
+static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
+    "tflite-flatbuffer-to-mlir",
+    [](llvm::SourceMgr& source_mgr, MLIRContext* context) {
+      return FlatBufferFileToMlirTrans(
+          &source_mgr, context, use_external_constant,
+          experimental_prune_unreachable_nodes_unconditionally);
+    });
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
index ee7ac81dce9..e8337d4a79f 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
@@ -13,6 +13,31 @@ See the License for the specific language governing permissions and
 limitations under the License.
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
index ee7ac81dce9..e8337d4a79f 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
@@ -13,6 +13,31 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/base/attributes.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
+#include "flatbuffers/flatbuffers.h"  // TF:flatbuffers
+#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/ToolOutputFile.h"
@@ -31,48 +56,67 @@ limitations under the License.
 #include "mlir/IR/Value.h"  // TF:llvm-project
 #include "mlir/Support/LogicalResult.h"  // TF:llvm-project
 #include "mlir/Translation.h"  // TF:llvm-project
-#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
-#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
-#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
+#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
+#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
+#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
+#include "tensorflow/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/string_util.h"
+#include "tensorflow/lite/tools/versioning/op_version.h"
+#include "tensorflow/lite/tools/versioning/runtime_version.h"
+#include "tensorflow/lite/version.h"
 
-using llvm::cl::opt;
+using llvm::dyn_cast;
+using llvm::formatv;
+using llvm::isa;
+using llvm::Optional;
+using llvm::StringRef;
+using llvm::Twine;
+using mlir::Dialect;
+using mlir::ElementsAttr;
+using mlir::FuncOp;
+using mlir::MLIRContext;
+using mlir::ModuleOp;
+using mlir::NoneType;
+using mlir::Operation;
+using mlir::Region;
+using mlir::StringAttr;
+using mlir::TensorType;
+using mlir::TranslateFromMLIRRegistration;
+using mlir::Type;
+using mlir::UnknownLoc;
+using mlir::Value;
+using tensorflow::OpOrArgLocNameMapper;
+using tensorflow::OpOrArgNameMapper;
+using tensorflow::Status;
+using tflite::flex::IsWhitelistedFlexOp;
+using xla::StatusOr;
 
-// Commandline flag to enable the control of flatbuffer import.
-bool use_external_constant;
+template <typename T>
+using BufferOffset = flatbuffers::Offset<T>;
 
-// Commandline flag to enable graph pruning.
-bool experimental_prune_unreachable_nodes_unconditionally;
+template <typename T>
+using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;
 
-// NOLINTNEXTLINE
-static opt<bool, true> use_external_constant_flag(
-    "use-external-constant",
-    llvm::cl::desc("Use external constant during flatbuffer import"),
-    llvm::cl::location(use_external_constant), llvm::cl::init(false));
+using CustomOptionsOffset = VectorBufferOffset<uint8_t>;
 
-// TODO(b/147111261): After the importer supports generic custom ops, we should
-// change the flag to a more lightwise flag, e.g.
-// "import_custom_ops_as_side_effect_free_ops", and let the MLIR DCE to prune
-// the operations.
-// NOLINTNEXTLINE
-static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
-    "experimental-prune-unreachable-nodes-unconditionally",
-    llvm::cl::desc("Prune nodes that are not ancestors of the output nodes."),
-    llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
-    llvm::cl::init(false));
+namespace error = tensorflow::error;
+namespace tfl = mlir::TFL;
 
-// NOLINTNEXTLINE
-static opt<std::string> input_arrays_flag(
-    "input-arrays",
-    llvm::cl::desc(
-        "List of input tensors, if different from the default inputs"),
-    llvm::cl::init(""));
-
-// NOLINTNEXTLINE
-static opt<std::string> output_arrays_flag(
-    "output-arrays",
-    llvm::cl::desc(
-        "List of output tensors, if different from the default outputs"),
-    llvm::cl::init(""));
+using llvm::cl::opt;
 
 // These command line flags enable control of the translation implementation.
@@ -113,48 +157,1353 @@ static opt<bool, true> strip_debug_info_flag(
     "strip-debug-info", llvm::cl::desc("Strip debug info during export"),
     llvm::cl::location(strip_debug_info), llvm::cl::init(false));
 
-namespace mlir {
-namespace {
-static OwningModuleRef FlatBufferFileToMlirTrans(
-    llvm::SourceMgr* source_mgr, MLIRContext* context,
-    bool use_external_constant,
-    bool experimental_prune_unreachable_nodes_unconditionally) {
-  const llvm::MemoryBuffer* input =
-      source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
-  std::string error;
-  auto loc =
-      mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
+ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
 
-  // Parses input/output names from command line options.
-  std::vector<std::string> inputs;
-  std::vector<std::string> outputs;
-  // Use output parser since we only have tensor names.
-  if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
-    return emitError(loc, "parsing input array info failed ")
-               << input_arrays_flag,
-           nullptr;
+// Use an initial buffer size in the flatbuffer builder that matches the
+// initial size used by the TOCO export. (TOCO does not document the rationale
+// for this choice.)
+constexpr size_t kInitialBufferSize = 10240;
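To make the restored aliases concrete: BufferOffset and VectorBufferOffset are just typed offsets into the flatbuffer under construction, and kInitialBufferSize only seeds the builder's scratch allocation. A small standalone sketch; the helper name is illustrative:

    #include <cstdint>

    #include "flatbuffers/flatbuffers.h"
    #include "tensorflow/lite/schema/schema_generated.h"

    // Builds a standalone tflite::Buffer table holding four bytes; the return
    // type is what the alias BufferOffset<tflite::Buffer> denotes.
    flatbuffers::Offset<tflite::Buffer> MakeDemoBuffer(
        flatbuffers::FlatBufferBuilder& builder) {
      const uint8_t bytes[] = {1, 2, 3, 4};
      // CreateVector returns flatbuffers::Offset<flatbuffers::Vector<uint8_t>>,
      // i.e. a VectorBufferOffset<uint8_t> in the terms used above.
      auto data = builder.CreateVector(bytes, sizeof(bytes));
      return tflite::CreateBuffer(builder, data);
    }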
+
+// Set `isSigned` to false if the `type` is an 8-bit unsigned integer type.
+// Since TFLite does not support unsigned integer types of other widths, an
+// error is returned if `isSigned` is set to false for any other type.
+static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
+                                                  bool is_signed = true) {
+  if (!is_signed && type.isSignlessInteger(8)) {
+    return tflite::TensorType_UINT8;
   }
-  if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
-    return emitError(loc, "parsing output array info failed ")
-               << output_arrays_flag,
-           nullptr;
+  if (!is_signed) {
+    return Status(error::INVALID_ARGUMENT,
+                  "'isSigned' can only be set for 8-bit integer type");
+  }
+  switch (type.getKind()) {
+    case mlir::StandardTypes::F32:
+      return tflite::TensorType_FLOAT32;
+    case mlir::StandardTypes::F16:
+      return tflite::TensorType_FLOAT16;
+    case mlir::TF::TensorFlowTypes::STRING:
+      return tflite::TensorType_STRING;
+    case mlir::TF::TensorFlowTypes::QUINT8:
+      return tflite::TensorType_UINT8;
+    case mlir::StandardTypes::Complex: {
+      auto ftype = type.cast<mlir::ComplexType>().getElementType();
+      if (ftype && ftype.isF32()) {
+        return tflite::TensorType_COMPLEX64;
+      }
+      return Status(error::INVALID_ARGUMENT, "Unsupported type");
+    }
+    case mlir::StandardTypes::Integer: {
+      const auto& itype = type.cast<mlir::IntegerType>();
+      switch (itype.getWidth()) {
+        case 1:
+          return tflite::TensorType_BOOL;
+        case 8:
+          return itype.isUnsigned() ? tflite::TensorType_UINT8
+                                    : tflite::TensorType_INT8;
+        case 16:
+          return tflite::TensorType_INT16;
+        case 32:
+          return tflite::TensorType_INT32;
+        case 64:
+          return tflite::TensorType_INT64;
+      }
+    }
+    case mlir::quant::QuantizationTypes::UniformQuantized: {
+      auto qtype = type.cast<mlir::quant::UniformQuantizedType>();
+      return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
+    }
+    case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: {
+      auto qtype = type.cast<mlir::quant::UniformQuantizedPerAxisType>();
+      return GetTFLiteType(qtype.getStorageType(), qtype.isSigned());
+    }
+    case mlir::TF::TensorFlowTypes::RESOURCE: {
+      // Treat tf.resource values as integer values in flatbuffer.
+      // TODO(b/146131919): Maybe need to have a detailed design for supporting
+      // other resource types beyond hash table resources and resource
+      // variables.
+      return tflite::TensorType_INT32;
+    }
+    default:
+      // TFLite export fills FLOAT32 for unknown data types. Returning an error
+      // for now for safety; this could be revisited when required.
+      return Status(error::INVALID_ARGUMENT, "Unsupported type");
   }
-  return tflite::FlatBufferToMlir(
-      absl::string_view(input->getBufferStart(), input->getBufferSize()),
-      context, loc, use_external_constant, inputs, outputs,
-      experimental_prune_unreachable_nodes_unconditionally);
 }
-static LogicalResult MlirToFlatBufferFileTranslateFunction(
+static bool IsConst(Operation* op) {
+  return isa<mlir::ConstantOp>(op) || isa<mlir::TF::ConstOp>(op) ||
+         isa<tfl::ConstOp>(op) || isa<tfl::QConstOp>(op);
+}
+
+template <typename T>
+static bool HasValidTFLiteType(Value value, T& error_handler) {
+  // None type is allowed to represent unspecified operands.
+  if (value.getType().isa<NoneType>()) return true;
+
+  auto type = value.getType().dyn_cast<TensorType>();
+  if (!type) {
+    if (auto op = value.getDefiningOp()) {
+      error_handler.emitError()
+          << '\'' << op << "' should produce value of tensor type instead of "
+          << value.getType();
+      return false;
+    }
+    error_handler.emitError("expected tensor type, got ") << value.getType();
+    return false;
+  }
+
+  Type element_type = type.getElementType();
+  auto status = GetTFLiteType(element_type);
+  if (!status.ok()) {
+    return error_handler.emitError(
+               formatv("Failed to convert element type '{0}': {1}",
+                       element_type, status.status().error_message())),
+           false;
+  }
+  return true;
+}
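A quick sanity sketch of the mapping above, written as if GetTFLiteType were visible outside this file; it is file-static, so this snippet is purely illustrative:

    #include <cassert>

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/MLIRContext.h"
    #include "tensorflow/lite/schema/schema_generated.h"

    void CheckScalarTypeMappings() {
      mlir::MLIRContext context;
      mlir::Builder builder(&context);
      // Per the switch above: F32 -> FLOAT32, i1 -> BOOL, i64 -> INT64.
      assert(GetTFLiteType(builder.getF32Type()).ValueOrDie() ==
             tflite::TensorType_FLOAT32);
      assert(GetTFLiteType(builder.getIntegerType(1)).ValueOrDie() ==
             tflite::TensorType_BOOL);
      assert(GetTFLiteType(builder.getIntegerType(64)).ValueOrDie() ==
             tflite::TensorType_INT64);
      // Unsigned is only representable at 8 bits: is_signed=false on i8 -> UINT8.
      assert(GetTFLiteType(builder.getIntegerType(8), /*is_signed=*/false)
                 .ValueOrDie() == tflite::TensorType_UINT8);
    }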
+
+// Returns true if the module holds all the invariants expected by the
+// Translator class.
+// TODO(hinsu): Now that translation is done by making a single pass over the
+// MLIR module, consider inlining these validation checks at the place where
+// these invariants are assumed instead of checking upfront.
+static bool IsValidTFLiteMlirModule(ModuleOp module) {
+  MLIRContext* context = module.getContext();
+
+  // Verify that module has a function named main.
+  FuncOp main_fn = module.lookupSymbol<FuncOp>("main");
+  if (!main_fn) {
+    return emitError(UnknownLoc::get(context),
+                     "should have a function named 'main'"),
+           false;
+  }
+
+  for (auto fn : module.getOps<FuncOp>()) {
+    if (fn.getBlocks().size() != 1) {
+      return fn.emitError("should have exactly one basic block"), false;
+    }
+    auto& bb = fn.getBlocks().front();
+
+    for (auto arg : bb.getArguments()) {
+      if (!HasValidTFLiteType(arg, fn))
+        return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
+    }
+
+    // Verify that all operations except the terminator produce only results
+    // of types supported by TFLite.
+    for (auto& inst : bb) {
+      if (inst.isKnownTerminator()) break;
+
+      for (auto result : inst.getResults()) {
+        if (!HasValidTFLiteType(result, inst))
+          return fn.emitError("invalid TFLite type: ") << result.getType(),
+                 false;
+      }
+    }
+  }
+
+  return true;
+}
+
+static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
+    ::mlir::Operation* inst) {
+  // We pass empty string for the original node_def name since Flex runtime
+  // does not care about this being set correctly on node_def. There is no
+  // "easy" (see b/120948529) way yet to get this from MLIR inst.
+  auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef(
+      inst, /*name=*/"", /*ignore_unregistered_attrs=*/true);
+  if (!status_or_node_def.ok()) {
+    inst->emitOpError(
+        Twine("failed to obtain TensorFlow nodedef with status: " +
+              status_or_node_def.status().ToString()));
+    return {};
+  }
+  return std::move(status_or_node_def.ValueOrDie());
+}
+
+// Converts an MLIR padding StringRef to TfLitePadding.
+// Returns llvm::None if conversion fails.
+static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
+                                                llvm::StringRef padding) {
+  const tflite::Padding padding_attr =
+      std::move(llvm::StringSwitch<tflite::Padding>(padding)
+                    .Case("SAME", tflite::Padding_SAME)
+                    .Case("VALID", tflite::Padding_VALID));
+  if (padding_attr == tflite::Padding_SAME) {
+    return kTfLitePaddingSame;
+  }
+  if (padding_attr == tflite::Padding_VALID) {
+    return kTfLitePaddingValid;
+  }
+
+  return inst->emitOpError() << "Invalid padding attribute: " << padding,
+         llvm::None;
+}
+
+// Extracts TfLitePoolParams from a TFL custom op.
+// The template parameter, TFLOp, should be a TFL custom op containing
+// attributes generated from TfLitePoolParams.
+// Returns llvm::None if conversion fails.
+template <typename TFLOp>
+static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
+                                                      TFLOp op) {
+  TfLitePoolParams pool_params;
+  pool_params.stride_height = op.stride_h().getSExtValue();
+  pool_params.stride_width = op.stride_w().getSExtValue();
+  pool_params.filter_height = op.filter_h().getSExtValue();
+  pool_params.filter_width = op.filter_w().getSExtValue();
+  const auto padding = GetTflitePadding(inst, op.padding());
+  if (padding) {
+    pool_params.padding = *padding;
+    pool_params.activation = kTfLiteActNone;
+    pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
+    return pool_params;
+  }
+
+  return llvm::None;
+}
+
+namespace {
+
+// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
+class Translator {
+ public:
+  // Translates the given MLIR module into TFLite FlatBuffer format and returns
+  // the serialized output. Returns llvm::None on unsupported or invalid
+  // inputs, or on internal errors.
+  static Optional<std::string> Translate(
+      ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
+      bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper);
+
+ private:
+  enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
+  explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
+                      bool emit_select_tf_ops, bool emit_custom_ops,
+                      OpOrArgNameMapper* op_or_arg_name_mapper)
+      : module_(module),
+        name_mapper_(*op_or_arg_name_mapper),
+        builder_(kInitialBufferSize) {
+    // The first buffer must be empty according to the schema definition.
+    empty_buffer_ = tflite::CreateBuffer(builder_);
+    buffers_.push_back(empty_buffer_);
+    if (emit_builtin_tflite_ops) {
+      enabled_op_types_.emplace(OpType::kTfliteBuiltin);
+    }
+    if (emit_select_tf_ops) {
+      enabled_op_types_.emplace(OpType::kSelectTf);
+    }
+    if (emit_custom_ops) {
+      enabled_op_types_.emplace(OpType::kCustomOp);
+    }
+    tf_dialect_ = module.getContext()->getRegisteredDialect("tf");
+    tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl");
+  }
+
+  Optional<std::string> TranslateInternal();
+
+  // Returns a TFLite buffer populated with the constant value if the operation
+  // is a TFLite constant operation. Otherwise, returns an empty buffer. Emits
+  // an error and returns llvm::None on failure.
+  Optional<BufferOffset<tflite::Buffer>> BuildBuffer(Operation* inst);
+
+  // Builds a TFLite tensor from the given type. This function is for tfl.lstm
+  // intermediates, which should have UniformQuantizedType.
+  Optional<BufferOffset<tflite::Tensor>> BuildTensorFromType(
+      mlir::Type type, const std::string& name);
+
+  // Builds a TFLite tensor from the given value. `buffer_idx` is the index of
+  // the corresponding buffer. Emits an error and returns llvm::None on
+  // failure.
+  Optional<BufferOffset<tflite::Tensor>> BuildTensor(Value value,
+                                                     const std::string& name,
+                                                     unsigned buffer_idx);
+
+  // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove
+  // these 2 functions here.
+  BufferOffset<tflite::Operator> BuildIfOperator(
+      mlir::TF::IfOp op, const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+  BufferOffset<tflite::Operator> BuildWhileOperator(
+      mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+
+  // Builds a while operator where cond & body are regions.
+  Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
+      mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+
+  // Builds custom operators.
+  // Templated on a) data type of custom_option to be stored into flatbuffer,
+  // and b) TFL custom op type.
+  template <typename CustomOptionType, typename TFLOp>
+  BufferOffset<tflite::Operator> BuildCustomOperator(
+      const CustomOptionType& custom_option, const std::string& opcode_name,
+      TFLOp op, const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+
+  BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
+      mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+  Optional<BufferOffset<tflite::Operator>>
+  BuildConvolution2DTransposeBiasOperator(
+      Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
+      const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+  Optional<BufferOffset<tflite::Operator>> BuildMaxPoolingWithArgMax2DOperator(
+      Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
+      const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+  Optional<BufferOffset<tflite::Operator>> BuildMaxUnpooling2DOperator(
+      Operation* inst, mlir::TFL::MaxUnpooling2DOp op,
+      const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+
+  Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
+      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
+
+  Optional<CustomOptionsOffset> CreateCustomOpCustomOptions(
+      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
+
+  std::unique_ptr<flexbuffers::Builder> CreateFlexBuilderWithNodeAttrs(
+      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
+
+  // Returns the opcode index for the op identified by op_name, if already
+  // available. Otherwise, creates a new OperatorCode using the given `builtin`
+  // operator and associates it with `op_name`.
+  uint32_t GetOpcodeIndex(const std::string& op_name,
+                          tflite::BuiltinOperator builtin);
+
+  // Builds an operator for the given operation with specified operand and
+  // result tensor indices. Emits an error and returns llvm::None on failure.
+  Optional<BufferOffset<tflite::Operator>> BuildOperator(
+      Operation* inst, const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results,
+      const std::vector<int32_t>& intermediates);
+
+  // Builds a subgraph with a given name out of the region either corresponding
+  // to a function's body or while op.
+  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
+      const std::string& name, Region* region);
+
+  // Builds Metadata with the given `name` and buffer `content`.
+  BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
+                                               StringRef content);
+
+  // Encodes the `tfl.metadata` dictionary attribute of the module to the
+  // metadata section in the final model.
+  Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
+  CreateMetadataVector();
+
+  // Uses the tf.entry_function attribute (if set) to initialize the op to name
+  // mapping.
+  void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);
+
+  // Determines if the specified operation op's operand at operand_index
+  // is marked as a stateful operand.
+  bool IsStatefulOperand(mlir::Operation* op, int operand_index);
+
+  // Returns a unique name for `val`.
+  std::string UniqueName(mlir::Value val);
+
+  ModuleOp module_;
+
+  tensorflow::OpOrArgNameMapper& name_mapper_;
+
+  flatbuffers::FlatBufferBuilder builder_;
+  BufferOffset<tflite::Buffer> empty_buffer_;
+
+  std::vector<BufferOffset<tflite::Buffer>> buffers_;
+
+  // Maps op name to index of the corresponding OperatorCode in opcodes_
+  // vector.
+  absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
+  std::vector<BufferOffset<tflite::OperatorCode>> opcodes_;
+
+  // Maps function name to index of the corresponding subgraph in the
+  // FlatBuffer model.
+  absl::flat_hash_map<std::string, int> subgraph_index_map_;
+  absl::flat_hash_set<OpType> enabled_op_types_;
+
+  // Points to TensorFlow and TFLite dialects, respectively. nullptr if the
+  // dialect is not registered.
+  const Dialect* tf_dialect_;
+  const Dialect* tfl_dialect_;
+
+  // The failed ops during legalization.
+  std::set<std::string> failed_flex_ops_;
+  std::set<std::string> failed_custom_ops_;
+};
+
+std::string Translator::UniqueName(mlir::Value val) {
+  return std::string(name_mapper_.GetUniqueName(val));
+}
+
+Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
+    Operation* inst) {
+  ElementsAttr attr;
+  if (auto cst = dyn_cast<mlir::ConstantOp>(inst)) {
+    // ConstantOp has an ElementsAttr at this point due to the validation of
+    // the TFLite module.
+    attr = cst.getValue().cast<ElementsAttr>();
+  } else if (auto cst = dyn_cast<mlir::TF::ConstOp>(inst)) {
+    attr = cst.value();
+  } else if (auto cst = dyn_cast<tfl::ConstOp>(inst)) {
+    attr = cst.value();
+  } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
+    attr = cst.value();
+  } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
+    attr = cst.value();
+  } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
+    attr = cst.value();
+  } else {
+    return empty_buffer_;
+  }
+
+  tensorflow::Tensor tensor;
+  auto status = tensorflow::ConvertToTensor(attr, &tensor);
+  if (!status.ok()) {
+    inst->emitError(
+        Twine("failed to convert value attribute to tensor with error: " +
+              status.ToString()));
+    return llvm::None;
+  }
+
+  // TensorFlow and TensorFlow Lite use different string encoding formats.
+  // Convert to the TensorFlow Lite format if it's a constant string tensor.
+  if (tensor.dtype() == tensorflow::DT_STRING) {
+    ::tflite::DynamicBuffer dynamic_buffer;
+    auto flat = tensor.flat<::tensorflow::tstring>();
+    for (int i = 0; i < flat.size(); ++i) {
+      const auto& str = flat(i);
+      dynamic_buffer.AddString(str.c_str(), str.length());
+    }
+    char* tensor_buffer;
+    int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer);
+    auto buffer_data =
+        builder_.CreateVector(reinterpret_cast<uint8_t*>(tensor_buffer), bytes);
+    free(tensor_buffer);
+    return tflite::CreateBuffer(builder_, buffer_data);
+  }
+
+  absl::string_view tensor_data = tensor.tensor_data();
+  auto buffer_data = builder_.CreateVector(
+      reinterpret_cast<const uint8_t*>(tensor_data.data()), tensor_data.size());
+  return tflite::CreateBuffer(builder_, buffer_data);
+}
+
+Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensorFromType(
+    mlir::Type type, const std::string& name) {
+  auto tensor_type = type.cast<TensorType>();
+
+  if (!tensor_type.hasStaticShape()) {
+    return llvm::None;
+  }
+  llvm::ArrayRef<int64_t> shape_ref = tensor_type.getShape();
+  std::vector<int32_t> shape(shape_ref.begin(), shape_ref.end());
+
+  auto element_type = tensor_type.getElementType();
+  tflite::TensorType tflite_element_type =
+      GetTFLiteType(tensor_type.getElementType()).ValueOrDie();
+  BufferOffset<tflite::QuantizationParameters> q_params;
+  auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>();
+  if (!qtype) {
+    return llvm::None;
+  }
+  q_params = tflite::CreateQuantizationParameters(
+      builder_, /*min=*/0, /*max=*/0,
+      builder_.CreateVector<float>({static_cast<float>(qtype.getScale())}),
+      builder_.CreateVector<int64_t>({qtype.getZeroPoint()}));
+  return tflite::CreateTensor(
+      builder_, builder_.CreateVector(shape), tflite_element_type,
+      /*buffer=*/0, builder_.CreateString(name), q_params,
+      /*is_variable=*/false);
+}
+
+Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
+    Value value, const std::string& name, unsigned buffer_idx) {
+  auto type = value.getType().cast<TensorType>();
+
+  // TFLite requires tensor shape only for the inputs and constants.
+ // However, we output all known shapes for better round-tripping + auto check_shape = + [&](llvm::ArrayRef shape_ref) -> mlir::LogicalResult { + auto is_out_of_range = [](int64_t dim) { + return dim > std::numeric_limits::max(); + }; + + if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range)) + return mlir::emitError( + value.getLoc(), + "result shape dimensions out of 32 bit int type range"); + + return mlir::success(); + }; + + std::vector shape; + std::vector shape_signature; + if (type.hasStaticShape()) { + llvm::ArrayRef shape_ref = type.getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape = std::vector(shape_ref.begin(), shape_ref.end()); + } else if (auto* inst = value.getDefiningOp()) { + if (IsConst(inst)) { + // Const op can have a result of dynamic shaped type (e.g. due to constant + // folding), but we can still derive the shape of a constant tensor for + // its attribute type. + mlir::Attribute tensor_attr = inst->getAttr("value"); + llvm::ArrayRef shape_ref = + tensor_attr.getType().cast().getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape = std::vector(shape_ref.begin(), shape_ref.end()); + } + } else if (type.hasRank()) { + llvm::ArrayRef shape_ref = type.getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape.reserve(shape_ref.size()); + for (auto& dim : shape_ref) { + shape.push_back(dim == -1 ? 1 : dim); + } + shape_signature = std::vector(shape_ref.begin(), shape_ref.end()); + } + + if (auto* inst = value.getDefiningOp()) { + if (auto cst = dyn_cast(inst)) { + // CreateSparsityParameters(cst.s_param()); + } else if (auto cst = dyn_cast(inst)) { + // CreateSparsityParameters(cst.s_param()); + } + } + + Type element_type = type.getElementType(); + tflite::TensorType tflite_element_type = + GetTFLiteType(type.getElementType()).ValueOrDie(); + + BufferOffset q_params; + if (auto qtype = element_type.dyn_cast()) { + q_params = tflite::CreateQuantizationParameters( + // TODO(fengliuai): min and max values are not stored in the + // quantized type, so both are set to 0. The model couldn't be imported + // to TensorFlow because of this. + builder_, /*min=*/0, /*max=*/0, + builder_.CreateVector({static_cast(qtype.getScale())}), + builder_.CreateVector({qtype.getZeroPoint()})); + } else if (auto qtype = + element_type + .dyn_cast()) { + std::vector scales(qtype.getScales().begin(), + qtype.getScales().end()); + q_params = tflite::CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), + builder_.CreateVector(qtype.getZeroPoints()), + tflite::QuantizationDetails_NONE, /*details=*/0, + qtype.getQuantizedDimension()); + } else { + q_params = tflite::CreateQuantizationParameters(builder_); + } + // Check if the value's uses includes an op and usage at an operand index + // marked as a stateful. If so, set the tensor's is_variable as true + // This is v1 ref variable semantics in the TFLite runtime. + bool is_variable = false; + for (auto& use : value.getUses()) { + is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); + if (is_variable) { + break; + } + } + + if (shape_signature.empty()) { + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable); + } else { + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 
0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable, /*sparsity=*/0, + /*shape_signature=*/builder_.CreateVector(shape_signature)); + } +} + +BufferOffset Translator::BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); + int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); + int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); + auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, + else_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_IfOptions, + builtin_options); +} + +BufferOffset Translator::BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); + int cond_subgraph_index = subgraph_index_map_.at(op.cond().str()); + int body_subgraph_index = subgraph_index_map_.at(op.body().str()); + auto builtin_options = tflite::CreateWhileOptions( + builder_, cond_subgraph_index, body_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_WhileOptions, + builtin_options); +} + +Optional> Translator::BuildWhileOperator( + mlir::TFL::WhileOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); + auto get_call_index = [&](mlir::Block& b) -> Optional { + if (b.getOperations().size() != 2) return llvm::None; + if (auto call_op = dyn_cast(b.front())) + return subgraph_index_map_.at(call_op.callee().str()); + return llvm::None; + }; + auto body_subgraph_index = get_call_index(op.body().front()); + auto cond_subgraph_index = get_call_index(op.cond().front()); + if (!body_subgraph_index || !cond_subgraph_index) + return op.emitOpError("only single call cond/body while export supported"), + llvm::None; + auto builtin_options = + tflite::CreateWhileOptions(builder_, *cond_subgraph_index, + *body_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_WhileOptions, + builtin_options); +} + +template +BufferOffset Translator::BuildCustomOperator( + const CustomOptionType& custom_option, const std::string& opcode_name, + TFLOp op, const std::vector& operands, + const std::vector& results) { + std::vector custom_option_vector(sizeof(CustomOptionType)); + memcpy(custom_option_vector.data(), &custom_option, sizeof(CustomOptionType)); + auto opcode_index = + GetOpcodeIndex(opcode_name, tflite::BuiltinOperator_CUSTOM); + return tflite::CreateOperator( + builder_, opcode_index, builder_.CreateVector(operands), + builder_.CreateVector(results), tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, + builder_.CreateVector(custom_option_vector), + tflite::CustomOptionsFormat_FLEXBUFFERS); +} + +BufferOffset Translator::BuildNumericVerifyOperator( + mlir::TFL::NumericVerifyOp op, const std::vector& operands, + const std::vector& results) { + float tolerance = op.tolerance().convertToFloat(); 
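+  // (For reference: BuildCustomOperator, defined above, byte-copies its POD
+  // option argument, so this float tolerance is serialized as a 4-byte
+  // custom_options vector. The effect is equivalent to
+  //   std::vector<uint8_t> bytes(sizeof(float));
+  //   std::memcpy(bytes.data(), &tolerance, sizeof(float));
+  // with the raw bytes attached under the CustomOptionsFormat_FLEXBUFFERS
+  // tag, exactly as the template above shows.)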
+ return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results); +} + +Optional> +Translator::BuildConvolution2DTransposeBiasOperator( + Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, + const std::vector& operands, const std::vector& results) { + TfLiteTransposeConvParams conv_params; + conv_params.stride_height = op.stride_h().getSExtValue(); + conv_params.stride_width = op.stride_w().getSExtValue(); + const auto padding = GetTflitePadding(inst, op.padding()); + if (padding) { + conv_params.padding = *padding; + return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op, + operands, results); + } + + return llvm::None; +} + +Optional> +Translator::BuildMaxPoolingWithArgMax2DOperator( + Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, + const std::vector& operands, const std::vector& results) { + const auto pool_params = GetTflitePoolParams(inst, op); + if (pool_params) { + return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op, + operands, results); + } + + return llvm::None; +} + +Optional> +Translator::BuildMaxUnpooling2DOperator(Operation* inst, + mlir::TFL::MaxUnpooling2DOp op, + const std::vector& operands, + const std::vector& results) { + const auto pool_params = GetTflitePoolParams(inst, op); + if (pool_params) { + return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands, + results); + } + + return llvm::None; +} + +Optional Translator::CreateFlexOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + std::string node_def_str; + if (!node_def.SerializeToString(&node_def_str)) { + return emitError(loc, "failed to serialize tensorflow node_def"), + llvm::None; + } + + auto flex_builder = absl::make_unique(); + flex_builder->Vector([&]() { + flex_builder->String(node_def.op()); + flex_builder->String(node_def_str); + }); + flex_builder->Finish(); + return builder_.CreateVector(flex_builder->GetBuffer()); +} + +Optional Translator::CreateCustomOpCustomOptions( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + std::string node_def_str; + if (!node_def.SerializeToString(&node_def_str)) { + return emitError(loc, "failed to serialize tensorflow node_def"), + llvm::None; + } + auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); + return builder_.CreateVector(flex_builder->GetBuffer()); +} + +std::unique_ptr +Translator::CreateFlexBuilderWithNodeAttrs( + const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { + auto flex_builder = absl::make_unique(); + size_t map_start = flex_builder->StartMap(); + for (const auto& pair : node_def.attr()) { + const char* key = pair.first.c_str(); + const auto& attr = pair.second; + switch (attr.value_case()) { + case ::tensorflow::AttrValue::kS: + flex_builder->String(key, attr.s()); + break; + case ::tensorflow::AttrValue::kType: { + auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type()); + if (status_or_tfl_type.ok()) { + flex_builder->Int(key, status_or_tfl_type.ValueOrDie()); + } else { + emitWarning(loc, "ignoring unsupported tensorflow type: ") + << std::to_string(attr.type()); + } + break; + } + case ::tensorflow::AttrValue::kI: + flex_builder->Int(key, attr.i()); + break; + case ::tensorflow::AttrValue::kF: + flex_builder->Float(key, attr.f()); + break; + case ::tensorflow::AttrValue::kB: + flex_builder->Bool(key, attr.b()); + break; + case tensorflow::AttrValue::kList: + if (attr.list().s_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const std::string& v : 
attr.list().s()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else if (attr.list().i_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const int64_t v : attr.list().i()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else if (attr.list().f_size() > 0) { + auto start = flex_builder->StartVector(key); + for (const float v : attr.list().f()) { + flex_builder->Add(v); + } + flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); + } else { + emitWarning(loc, + "ignoring unsupported type in list attribute with key: ") + << key; + } + break; + default: + emitWarning(loc, "ignoring unsupported attribute type with key: ") + << key; + break; + } + } + flex_builder->EndMap(map_start); + flex_builder->Finish(); + return flex_builder; +} + +uint32_t Translator::GetOpcodeIndex(const std::string& op_name, + tflite::BuiltinOperator builtin) { + auto it = opcode_index_map_.insert({op_name, 0}); + + // If the insert succeeded, the opcode has not been created already. Create a + // new operator code and update its index value in the map. + if (it.second) { + it.first->second = opcodes_.size(); + auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM + ? builder_.CreateString(op_name) + : BufferOffset(); + // Use version 0 for builtin op. This is a way to serialize version field to + // flatbuffer (since 0 is non default) and it will be corrected later. + int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1; + opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin, + custom_code, op_version)); + } + return it.first->second; +} + +Optional> Translator::BuildOperator( + Operation* inst, const std::vector& operands, + const std::vector& results, + const std::vector& intermediates) { + const auto* dialect = inst->getDialect(); + if (!dialect) { + inst->emitOpError("dialect is not registered"); + return llvm::None; + } + + // If TFLite built in op, create operator as a builtin op. + if (dialect == tfl_dialect_) { + // Only if built-in TFLite op emission is enabled, would legalization have + // converted any TF->TFL. 
+ if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) { + return inst->emitOpError( + "is a TFLite builtin op but builtin emission is not enabled"), + llvm::None; + } + + auto builtin_code = GetBuiltinOpCode(inst); + if (!builtin_code) { + if (auto verify_op = dyn_cast(inst)) { + return BuildNumericVerifyOperator(verify_op, operands, results); + } + if (auto conv_transpose_bias_op = + dyn_cast(inst)) { + return BuildConvolution2DTransposeBiasOperator( + inst, conv_transpose_bias_op, operands, results); + } + if (auto max_pooling_with_arg_max_op = + dyn_cast(inst)) { + return BuildMaxPoolingWithArgMax2DOperator( + inst, max_pooling_with_arg_max_op, operands, results); + } + if (auto max_unpooling_op = dyn_cast(inst)) { + return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands, + results); + } + if (auto whileOp = dyn_cast(inst)) { + if (inst->getNumOperands() != inst->getNumResults()) { + inst->emitOpError( + "number of operands and results don't match, only canonical " + "TFL While supported"); + return llvm::None; + } + return BuildWhileOperator(whileOp, operands, results); + } + + inst->emitOpError("is not a supported TFLite op"); + return llvm::None; + } + + std::string op_name = inst->getName().getStringRef().str(); + uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code); + auto offset = CreateFlatBufferOperator(inst, opcode_index, operands, + results, intermediates, &builder_); + if (!offset) { + inst->emitOpError("is not a supported TFLite op"); + } + return offset; + } + + if (dialect == tf_dialect_) { + std::string op_name; + if (auto ifOp = dyn_cast(inst)) { + return BuildIfOperator(ifOp, operands, results); + } else if (auto whileOp = dyn_cast(inst)) { + return BuildWhileOperator(whileOp, operands, results); + } + + CustomOptionsOffset custom_options; + + // Ops in TF dialect can either be custom ops or flex ops. + // The reason we go directly from TensorFlow dialect MLIR to tensorflow + // node instead of going to TF table gen'd ops via generated code is that + // we do not want to restrict custom and flex op conversion support to + // only those TF ops that are currently registered in MLIR. The current + // model is of an open op system. + // + // The following algorithm is followed: + // if flex is enabled and the op is whitelisted as flex + // we emit op as flex. + // if custom is enabled + // we emit the op as custom. + auto node_def = GetTensorFlowNodeDef(inst); + if (!node_def) { + return llvm::None; + } + + // Flex op case + // Eventually, the whitelist will go away and we will rely on some TF op + // trait (e.g. No side effect) to determine if it is a supported "Flex" + // op or not. + if (enabled_op_types_.contains(OpType::kSelectTf) && + IsWhitelistedFlexOp(node_def->op())) { + // Construct ops as flex op encoding TensorFlow node definition + // as custom options. + // Flex ops are named with the kFlexOpNamePrefix prefix to the actual + // TF op name. + op_name = std::string(kFlexOpNamePrefix) + node_def->op(); + if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; + } else { + return llvm::None; + } + } else if (enabled_op_types_.contains(OpType::kCustomOp)) { + // Generic case of custom ops - write using flex buffers since that + // is the only custom options supported by TFLite today. 
+ op_name = node_def->op(); + if (auto options = + CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; + } else { + return llvm::None; + } + } else { + // Create description of operation that could not be converted. + const int kLargeElementsAttr = 16; + std::string op_str; + llvm::raw_string_ostream os(op_str); + inst->getName().print(os); + // Print out attributes except for large elementsattributes (which should + // rarely be the cause why the legalization didn't happen). + if (!inst->getAttrList().getAttrs().empty()) { + os << " {"; + bool first = true; + for (auto& named_attr : inst->getAttrList().getDictionary()) { + os << (!first ? ", " : ""); + first = false; + named_attr.first.print(os); + os << " = "; + if (auto element_attr = named_attr.second.dyn_cast()) { + if (element_attr.getNumElements() <= kLargeElementsAttr) { + element_attr.print(os); + } else { + os << ""; + } + } else { + named_attr.second.print(os); + } + } + os << "}"; + } + + // Insert failed op to `flex_ops` or `custom_ops`. + if (IsWhitelistedFlexOp(node_def->op())) { + failed_flex_ops_.insert(os.str()); + } else { + failed_custom_ops_.insert(os.str()); + } + return inst->emitOpError("is neither a custom op nor a flex op"), + llvm::None; + } + + uint32_t opcode_index = + GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, + /*custom_options=*/custom_options, + tflite::CustomOptionsFormat_FLEXBUFFERS, + /*mutating_variable_inputs=*/0); + } + + return inst->emitOpError( + "is not any of a builtin TFLite op, a flex TensorFlow op or a " + "custom TensorFlow op"), + llvm::None; +} + +void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { + auto dict_attr = fn.getAttrOfType("tf.entry_function"); + if (!dict_attr) return; + + llvm::SmallVector input_names; + llvm::SmallVector output_names; + if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { + str.getValue().split(input_names, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + if (input_names.size() != fn.getNumArguments()) { + fn.emitWarning() << "invalid entry function specification"; + return; + } + for (auto it : llvm::enumerate(fn.getArguments())) { + name_mapper_.InitOpName(it.value(), input_names[it.index()].trim()); + } + *has_input_attr = true; + } + + if (auto str = + dict_attr.get("outputs").dyn_cast_or_null()) { + str.getValue().split(output_names, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + auto term = fn.getBlocks().back().getTerminator(); + if (output_names.size() != term->getNumOperands()) { + fn.emitWarning() << "output names (" << output_names.size() + << ") != terminator operands (" << term->getNumOperands() + << ")"; + return; + } + for (const auto& it : llvm::enumerate(term->getOperands())) { + name_mapper_.InitOpName(it.value(), output_names[it.index()].trim()); + } + } +} + +bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { + std::vector operand_indices; + if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; + return absl::c_find(operand_indices, operand_index) != operand_indices.end(); +} + +Optional> Translator::BuildSubGraph( + const std::string& name, Region* region) { + bool has_input_attr = false; + if (auto fn = dyn_cast(region->getParentOp())) { + InitializeNamesFromAttribute(fn, &has_input_attr); + 
} + std::vector> tensors; + llvm::DenseMap tensor_index_map; + + // Builds tensor and buffer for argument or operation result. Returns false + // on failure. + auto build_tensor_and_buffer = [&](Value value, const std::string& name) { + // NoneType represents optional and may be skipped here. + if (value.getType().isa()) { + return true; + } + + tensor_index_map.insert({value, tensors.size()}); + auto tensor_or = BuildTensor(value, name, buffers_.size()); + if (!tensor_or) return false; + tensors.push_back(*tensor_or); + + // TODO(ashwinm): Check if for stateful tensors, if it is also needed to + // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. + // This does not seem to affect runtime behavior for RNN/LSTM, but would be + // good for reducing memory footprint. + if (auto* inst = value.getDefiningOp()) { + auto buffer_or = BuildBuffer(inst); + if (!buffer_or) return false; + buffers_.push_back(*buffer_or); + } else { + buffers_.push_back(empty_buffer_); + } + return true; + }; + + std::vector> operators; + auto& bb = region->front(); + + // Main function's arguments are first passed to `input` op so they don't + // have associated tensor and buffer. Build FlatBuffer tensor and buffer for + // other functions. + for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) { + mlir::BlockArgument arg = bb.getArgument(i); + std::string name; + if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg)); + if (name.empty()) name = absl::StrCat("arg", i); + if (!build_tensor_and_buffer(arg, name)) return llvm::None; + } + + bool failed_once = false; + for (auto& inst : bb) { + if (inst.isKnownTerminator()) break; + std::vector intermediates; + // Build intermediate tensors for tfl.lstm and insert these tensors into + // flatbuffer. + if (llvm::isa(inst)) { + std::vector intermediate_names = { + "input_to_input_intermediate", "input_to_forget_intermediate", + "input_to_cell_intermediate", "input_to_output_intermediate", + "effective_hidden_scale_intermediate"}; + for (const std::string& intermediate : intermediate_names) { + auto intermediate_attr = inst.getAttr(intermediate); + if (auto attr = intermediate_attr.dyn_cast_or_null()) { + Type qtype = attr.getValue(); + auto tensor_or = BuildTensorFromType( + qtype, name_mapper_.GetUniqueName(intermediate).str()); + if (!tensor_or.hasValue()) { + continue; + } else { + intermediates.push_back(tensors.size()); + tensors.push_back(tensor_or.getValue()); + } + } + } + } + + for (auto val : inst.getResults()) { + std::string name = UniqueName(val); + if (!build_tensor_and_buffer(val, name)) return llvm::None; + } + + // Skip constant ops as they don't represent a TFLite operator. + if (IsConst(&inst)) continue; + + // Fetch operand and result tensor indices. + std::vector operands; + operands.reserve(inst.getNumOperands()); + for (auto operand : inst.getOperands()) { + if (operand.getType().isa()) + operands.push_back(kTfLiteOptionalTensor); + else + operands.push_back(tensor_index_map.lookup(operand)); + } + std::vector results; + results.reserve(inst.getNumOperands()); + for (auto result : inst.getResults()) { + results.push_back(tensor_index_map.lookup(result)); + } + + if (auto tfl_operator = + BuildOperator(&inst, operands, results, intermediates)) + operators.push_back(*tfl_operator); + else + failed_once = true; + } + + if (failed_once) return llvm::None; + + // Get input and output tensor indices for the subgraph. 
+ std::vector inputs, outputs; + for (auto arg : bb.getArguments()) { + inputs.push_back(tensor_index_map[arg]); + } + for (auto result : bb.getTerminator()->getOperands()) { + outputs.push_back(tensor_index_map[result]); + } + + return tflite::CreateSubGraph( + builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs), + builder_.CreateVector(outputs), builder_.CreateVector(operators), + /*name=*/builder_.CreateString(name)); +} + +BufferOffset Translator::BuildMetadata(StringRef name, + StringRef content) { + auto buffer_index = buffers_.size(); + auto buffer_data = builder_.CreateVector( + reinterpret_cast(content.data()), content.size()); + buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data)); + return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index); +} + +Optional>> +Translator::CreateMetadataVector() { + auto dict_attr = module_.getAttrOfType("tfl.metadata"); + std::vector> metadata; + if (dict_attr) { + for (const auto& named_attr : dict_attr) { + StringRef name = named_attr.first; + mlir::Attribute attr = named_attr.second; + if (auto content = attr.dyn_cast()) { + metadata.push_back(BuildMetadata(name, content.getValue())); + } else { + module_.emitError( + "all values in tfl.metadata's dictionary key-value pairs should be " + "string attributes"); + return llvm::None; + } + } + } + // Runtime version string is generated after we update the op + // versions. Here we put a 16-byte dummy string as a placeholder. We choose + // 16-byte because it's the alignment of buffers in flatbuffer, so it won't + // cause any waste of space if the actual string is shorter than 16 bytes. + metadata.push_back( + BuildMetadata("min_runtime_version", std::string(16, '\0'))); + return builder_.CreateVector(metadata); +} + +bool UpdateEntryFunction(ModuleOp module) { + if (module.lookupSymbol("main") != nullptr) { + // We already have an entry function. + return true; + } + + int entry_func_count = 0; + FuncOp entry_func = nullptr; + for (auto fn : module.getOps()) { + auto attrs = fn.getAttrOfType("tf.entry_function"); + if (attrs && !attrs.empty()) { + entry_func_count++; + entry_func = fn; + } + } + + // We should have one & only have one entry function. + if (entry_func_count != 1) return false; + + // Update the entry func to main. + entry_func.setName("main"); + return true; +} + +Optional Translator::Translate( + ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, + bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) { + if (!UpdateEntryFunction(module)) return llvm::None; + if (!IsValidTFLiteMlirModule(module)) return llvm::None; + Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, + emit_custom_ops, op_or_arg_name_mapper); + return translator.TranslateInternal(); +} + +Optional Translator::TranslateInternal() { + // A list of named regions in the module with main function being the first in + // the list. The main function is required as the first subgraph in the model + // is entry point for the model. + std::vector> named_regions; + named_regions.reserve(std::distance(module_.begin(), module_.end())); + + int subgraph_idx = 0; + FuncOp main_fn = module_.lookupSymbol("main"); + subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++; + named_regions.emplace_back("main", &main_fn.getBody()); + // Walk over the module collection ops with functions and while ops. 
+ module_.walk([&](FuncOp fn) { + if (fn != main_fn) { + subgraph_index_map_[fn.getName().str()] = subgraph_idx++; + named_regions.emplace_back(fn.getName().str(), &fn.getBody()); + } + }); + + // Build subgraph for each of the named regions. + std::vector> subgraphs; + subgraphs.reserve(named_regions.size()); + int first_failed_func = -1; + for (auto it : llvm::enumerate(named_regions)) { + auto subgraph_or = BuildSubGraph(it.value().first, it.value().second); + if (!subgraph_or) { + if (first_failed_func == -1) + // Record the index of the first region that cannot be converted. + // Keep looping through all subgraphs in the module to make sure that + // we collect the list of missing ops from the entire module. + first_failed_func = it.index(); + } else { + subgraphs.push_back(*subgraph_or); + } + } + + if (first_failed_func != -1) { + std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, "\n\t"); + std::string failed_custom_ops_list = + absl::StrJoin(failed_custom_ops_, "\n\t"); + std::string err; + if (!failed_flex_ops_list.empty()) + err += + "Ops that can be supported by the flex runtime (enabled via setting " + "the -emit-select-tf-ops flag):\n\t" + + failed_flex_ops_list; + if (!failed_custom_ops_list.empty()) + err += + "Ops that need custom implementation (enabled via setting the " + "-emit-custom-ops flag):\n\t" + + failed_custom_ops_list; + + auto& failed_region = named_regions[first_failed_func]; + return failed_region.second->getParentOp()->emitError() + << "failed while converting: '" << failed_region.first + << "': " << err, + llvm::None; + } + + std::string model_description; + if (auto attr = module_.getAttrOfType("tfl.description")) { + model_description = attr.getValue().str(); + } else { + model_description = "MLIR Converted."; + } + + // Build the model and finish the model building process. + auto description = builder_.CreateString(model_description.data()); + VectorBufferOffset metadata_buffer = 0; // Deprecated + auto metadata = CreateMetadataVector(); + if (!metadata) return llvm::None; + + auto model = tflite::CreateModel( + builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), + builder_.CreateVector(subgraphs), description, + builder_.CreateVector(buffers_), metadata_buffer, *metadata); + tflite::FinishModelBuffer(builder_, model); + tflite::UpdateOpVersion(builder_.GetBufferPointer()); + tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); + + // Return serialized string for the built FlatBuffer. + return std::string(reinterpret_cast(builder_.GetBufferPointer()), + builder_.GetSize()); +} + +} // namespace + +// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer +// format. Returns false on success. 
+//
+// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
+// the following:
+//
+// * Quantization
+// * Ops with variable tensors
+//
+bool tflite::MlirToFlatBufferTranslateFunction(
+    ModuleOp module, std::string* serialized_flatbuffer,
+    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
+    OpOrArgNameMapper* op_or_arg_name_mapper) {
+  auto maybe_translated =
+      Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops,
+                            emit_custom_ops, op_or_arg_name_mapper);
+  if (!maybe_translated) return true;
+  *serialized_flatbuffer = std::move(*maybe_translated);
+  return false;
+}
+
+bool tflite::MlirToFlatBufferTranslateFunction(
+    ModuleOp module, std::string* serialized_flatbuffer,
+    bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
+    bool emit_custom_ops) {
+  OpOrArgLocNameMapper op_or_arg_name_mapper;
+  return MlirToFlatBufferTranslateFunction(
+      module, serialized_flatbuffer, emit_builtin_tflite_ops,
+      emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper);
+}
+
+static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction(
     ModuleOp module, llvm::raw_ostream& output) {
   std::string serialized_flatbuffer;
-  std::unique_ptr<tensorflow::OpOrArgNameMapper> op_or_arg_name_mapper;
+  std::unique_ptr<OpOrArgNameMapper> op_or_arg_name_mapper;
   if (strip_debug_info) {
     op_or_arg_name_mapper =
         std::make_unique<tensorflow::OpOrArgStripNameMapper>();
   } else {
-    op_or_arg_name_mapper =
-        std::make_unique<tensorflow::OpOrArgLocNameMapper>();
+    op_or_arg_name_mapper = std::make_unique<OpOrArgLocNameMapper>();
   }
   if (tflite::MlirToFlatBufferTranslateFunction(
           module, &serialized_flatbuffer, emit_builtin_tflite_ops,
@@ -162,18 +1511,8 @@ static LogicalResult MlirToFlatBufferFileTranslateFunction(
     return mlir::failure();
 
   output << serialized_flatbuffer;
-  return success();
+  return mlir::success();
 }
-}  // namespace
-
-static TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
-    "tflite-flatbuffer-to-mlir",
-    [](llvm::SourceMgr& source_mgr, MLIRContext* context) {
-      return FlatBufferFileToMlirTrans(
-          &source_mgr, context, use_external_constant,
-          experimental_prune_unreachable_nodes_unconditionally);
-    });
 
 static TranslateFromMLIRRegistration MLIRToFlatBufferTranslate(
     "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction);
-}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.h b/tensorflow/compiler/mlir/lite/flatbuffer_translate.h
similarity index 90%
rename from tensorflow/compiler/mlir/lite/flatbuffer_export.h
rename to tensorflow/compiler/mlir/lite/flatbuffer_translate.h
index f89893d5c87..03f92ddbf03 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.h
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ #include @@ -40,4 +40,4 @@ bool MlirToFlatBufferTranslateFunction( tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); } // namespace tflite -#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h b/tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h similarity index 84% rename from tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h rename to tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h index 4e891a5b266..6c8f80d4e05 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ #include @@ -28,4 +28,4 @@ extern bool lower_tensor_list_ops; // The flag to control whether debug info gets stripped on export. extern bool strip_debug_info; -#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index d17215566a1..6f8292308a4 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -34,8 +34,8 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Module.h" // TF:llvm-project #include "mlir/Parser.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/delegates/flex/delegate.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 7557ff5223c..2f677397109 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Pass/PassManager.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index f04dc9c2961..c05337918f2 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -23,8 +23,8 @@ limitations under the License. #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Pass/PassManager.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index bb82988def1..74e48cd6d91 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -28,8 +28,8 @@ limitations under the License. #include "mlir/Support/FileUtilities.h" // TF:llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 1ba0c025613..b05dcaadab2 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Support/FileUtilities.h" // TF:llvm-project #include "mlir/Transforms/Passes.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 6cd058a15d2..8ac33c906bb 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -811,8 +811,7 @@ cc_library( srcs = ["utils/error_util.cc"], hdrs = ["utils/error_util.h"], deps = [ - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:status", + "//tensorflow/core:lib", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", ], diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc index 5514a788996..60646ae764e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/lib/core/errors.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h index 1bc0a23e359..7eb30ee2c46 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h @@ -21,7 +21,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // TF:llvm-project #include "mlir/IR/Location.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project -#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/lib/core/status.h" // Error utilities for MLIR when interacting with code using Status returns. namespace mlir {