diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 32a977416ae..789d06b8ac9 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -771,6 +771,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:support", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 0a3f0eb3518..1165561cb71 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -84,8 +84,14 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, TF_ASSIGN_OR_RETURN( auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context)); + mlir::TFL::PassConfig pass_config(quant_specs); + bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); + pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + pass_config.lower_tensor_list_ops = true; + pass_config.shape_inference = false; + return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), - quant_specs, result); + pass_config, result); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index f8435d17c8d..681773a7e6b 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -43,8 +43,6 @@ namespace tensorflow { Status ConvertSavedModelToTFLiteFlatBuffer( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, - const string& saved_model_dir, bool saved_model_v1, - const string& saved_model_tags, const string& saved_model_exported_names, string* result) { mlir::MLIRContext context; mlir::TFL::QuantizationSpecs quant_specs; @@ -66,13 +64,28 @@ Status ConvertSavedModelToTFLiteFlatBuffer( // Register all custom ops, including user-specified custom ops. TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags)); - const bool import_saved_model = !saved_model_v1; - TF_ASSIGN_OR_RETURN( - auto module, - ImportSavedModel(import_saved_model, saved_model_v1, saved_model_dir, - saved_model_tags, saved_model_exported_names, &context)); - return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module), - quant_specs, result); + auto& saved_model_tags = model_flags.saved_model_tags(); + auto& saved_model_exported_names = model_flags.saved_model_exported_names(); + std::unordered_set<std::string> tags(saved_model_tags.begin(), + saved_model_tags.end()); + auto exported_names_in_vector = std::vector<std::string>( + saved_model_exported_names.begin(), saved_model_exported_names.end()); + absl::Span<std::string> exported_names(exported_names_in_vector); + + TF_ASSIGN_OR_RETURN(auto module, + ImportSavedModel(model_flags.saved_model_dir(), + model_flags.saved_model_version(), tags, + exported_names, &context)); + + mlir::TFL::PassConfig pass_config(quant_specs); + bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); + pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + pass_config.lower_tensor_list_ops = true; + pass_config.shape_inference = true; + + auto status = internal::ConvertMLIRToTFLiteFlatBuffer( + toco_flags, std::move(module), pass_config, result); + return status; } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h index dea5603dad0..ed339ca64b9 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h @@ -28,8 +28,6 @@ namespace tensorflow { // status if it fails to convert the input. Status ConvertSavedModelToTFLiteFlatBuffer( const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, - const string& saved_model_dir, bool saved_model_v1, - const string& saved_model_tags, const string& saved_model_exported_names, string* result); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index ae342dd49ae..21761af382d 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -261,7 +261,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, - mlir::TFL::QuantizationSpecs quant_specs, + const mlir::TFL::PassConfig& pass_config, string* result) { bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); bool emit_select_tf_ops = toco_flags.enable_select_tf_ops(); @@ -275,9 +275,6 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, } mlir::PassManager pm(module->getContext()); - mlir::TFL::PassConfig pass_config(quant_specs); - pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; - pass_config.lower_tensor_list_ops = true; tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); // Convert back to outlined while format for export back to flatbuffer. @@ -288,7 +285,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, auto status = ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm); + emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result, + &pm); if (toco_flags.has_dump_graphviz_dir()) { TF_RETURN_IF_ERROR(DumpOpGraphToFile( // rename once we enable the new converter feature flag. diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index 96c2096e469..3ea36e5eb1d 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -47,7 +47,7 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, // This will also run relevant passes as well. Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, - mlir::TFL::QuantizationSpecs quant_specs, + const mlir::TFL::PassConfig& pass_config, string* result); // Give a warning for any unused flags that have been specified. diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 762bd8c8ed2..f7b3bf87222 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -138,13 +138,24 @@ int main(int argc, char **argv) { // TODO(b/147435528): We need to test the e2e behavior once the graph freezing // inside mlir is done. if (import_saved_model_object_graph || import_saved_model_signature_defs) { + int saved_model_version; + if (import_saved_model_object_graph) { + saved_model_version = 2; + } else { + saved_model_version = 1; + } if (input_mlir) module = tensorflow::errors::InvalidArgument( "Importing saved model should not have input_mlir set"); - module = tensorflow::ImportSavedModel(import_saved_model_object_graph, - import_saved_model_signature_defs, - input_file_name, saved_model_tags, - saved_model_exported_names, &context); + + std::unordered_set<std::string> tags = + absl::StrSplit(saved_model_tags, ','); + std::vector<std::string> exported_names_vector = + absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + absl::Span<std::string> exported_names(exported_names_vector); + + module = tensorflow::ImportSavedModel(input_file_name, saved_model_version, + tags, exported_names, &context); } else { module = tensorflow::LoadFromGraphdefOrMlirSource( input_file_name, input_mlir, use_splatted_constant, custom_opdefs, diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 7c0a91d6d4e..aacc1ad2fd6 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -160,25 +160,17 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( } StatusOr<mlir::OwningModuleRef> ImportSavedModel( - bool import_saved_model, bool import_saved_model_v1, - const std::string& input_filename, const std::string& saved_model_tags, - const std::string& saved_model_exported_names, mlir::MLIRContext* context) { - if (import_saved_model) { - std::unordered_set<std::string> tags = - absl::StrSplit(saved_model_tags, ','); - std::vector<std::string> exported_names = - absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); - + const std::string& input_filename, const int saved_model_version, + const std::unordered_set<std::string>& tags, + absl::Span<std::string> exported_names, mlir::MLIRContext* context) { + if (saved_model_version == 2) { auto module = tensorflow::SavedModelObjectGraphToMlirImport( - input_filename, tags, absl::Span<std::string>(exported_names), context); + input_filename, tags, exported_names, context); if (!module) return tensorflow::errors::InvalidArgument("fail to open input file"); return module; - } else if (import_saved_model_v1) { - std::unordered_set<std::string> tags = - absl::StrSplit(saved_model_tags, ','); - + } else if (saved_model_version == 1) { auto module = tensorflow::SavedModelSignatureDefsToMlirImport( input_filename, tags, context); diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index c93f8a6d416..d2c31a6b972 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ +#include <unordered_set> + +#include "absl/types/span.h" #include "llvm/Support/SourceMgr.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project @@ -42,9 +45,9 @@ LoadFromGraphdefOrMlirSource( // Load Saved model (either v1 or v2) into MLIR. stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel( - bool import_saved_model, bool import_saved_model_v1, - const std::string& input_filename, const std::string& saved_model_tags, - const std::string& saved_model_exported_names, mlir::MLIRContext* context); + const std::string& input_filename, const int saved_model_version, + const std::unordered_set<std::string>& tags, + absl::Span<std::string> exported_names, mlir::MLIRContext* context); // Taking a MLIR module in TF executor dialect and a set of parameters, // applies a set of passes to convert the module to TF Lite dialect and diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 1744defea94..89b0a91f665 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -257,7 +257,11 @@ def build_toco_convert_protos(input_tensors, target_ops=None, allow_nonexistent_arrays=False, debug_info=None, - conversion_summary_dir=None): + conversion_summary_dir=None, + saved_model_dir=None, + saved_model_version=0, + saved_model_tags=None, + saved_model_exported_names=None): """Builds protocol buffers describing a conversion of a model using TOCO. Typically this is to convert from TensorFlow GraphDef to TFLite, in which @@ -323,6 +327,18 @@ def build_toco_convert_protos(input_tensors, debug_info: `GraphDebugInfo` proto containing the stack traces for the original nodes referred by the converted graph. conversion_summary_dir: A string, the path to the generated conversion logs. + saved_model_dir: Filepath of the saved model to be converted. This value + will be non-empty only when the saved model import path will be used. + Otherwises, the graph def-based conversion will be processed. + saved_model_version: SavedModel file format version of The saved model file + to be converted. This value will be set only when the SavedModel import + path will be used. + saved_model_tags: Set of string saved model tags, formatted in the + comma-separated value. This value will be set only when the SavedModel + import path will be used. + saved_model_exported_names: Names to be exported (default: export all) when + the saved model import path is on. This value will be set only when the + SavedModel import path will be used. Returns: model_flags, toco_flags, debug_info: three protocol buffers describing the @@ -397,6 +413,14 @@ def build_toco_convert_protos(input_tensors, model.allow_nonexistent_arrays = allow_nonexistent_arrays + if saved_model_dir: + model.saved_model_dir = saved_model_dir + model.saved_model_version = saved_model_version + if saved_model_tags: + model.saved_model_tags.extend(saved_model_tags) + if saved_model_exported_names: + model.saved_model_exported_names.extend(saved_model_exported_names) + return model, toco, debug_info diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index ba9e6e0bd39..97d3f2a1ec6 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -74,6 +74,7 @@ from tensorflow.python.lib.io import file_io as _file_io from tensorflow.python.saved_model import signature_constants as _signature_constants from tensorflow.python.saved_model import tag_constants as _tag_constants from tensorflow.python.saved_model.load import load as _load +from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info from tensorflow.python.util import deprecation as _deprecation from tensorflow.python.util.tf_export import tf_export as _tf_export @@ -285,6 +286,10 @@ class TFLiteConverterBase(object): # The 'GraphDebugInfo' contains the stack traces of all the original nodes # in the `GraphDef` to the converter. self._debug_info = None + self._saved_model_dir = None + self._saved_model_tags = None + self._saved_model_version = None + self._saved_model_exported_names = [] def _grappler_config(self, optimizers=None): """Creates a tf.compat.v1.ConfigProto for configuring Grappler. @@ -346,8 +351,46 @@ class TFLiteConverterBase(object): "target_ops": self.target_spec.supported_ops, "enable_mlir_converter": self.experimental_new_converter, } + + if self._saved_model_dir: + args.update({ + "saved_model_dir": self._saved_model_dir, + "saved_model_version": self._saved_model_version, + "saved_model_tags": self._saved_model_tags, + "saved_model_exported_names": self._saved_model_exported_names, + }) + return args + def _contains_function_with_implements_attr(self, saved_model_proto): + meta_graph = saved_model_proto.meta_graphs[0] + for function in meta_graph.graph_def.library.function: + if function.attr.get("_implements", None) or function.attr.get( + "api_implements", None): + return True + return False + + def _parse_saved_model_args(self): + """Parses SavedModel arguments from the given Keras/RNN SavedModel.""" + if self._saved_model_dir: + try: + saved_model_proto, _ = ( + _parse_saved_model_with_debug_info(self._saved_model_dir)) + except OSError: + # If it fails to read the given saved model, it will fall back to the + # frozen graph def path. + self._saved_model_dir = None + return + if not self._contains_function_with_implements_attr(saved_model_proto): + self._saved_model_dir = None + else: + self._saved_model_exported_names = [] + self._saved_model_version = saved_model_proto.saved_model_schema_version + if self._saved_model_version not in [1, 2]: + raise ValueError( + "SavedModel file format({0}) is not supported".format( + self._saved_model_version)) + @_tf_export("lite.TFLiteConverter", v1=[]) class TFLiteConverterV2(TFLiteConverterBase): @@ -387,7 +430,11 @@ class TFLiteConverterV2(TFLiteConverterBase): ``` """ - def __init__(self, funcs, trackable_obj=None): + def __init__(self, + funcs, + trackable_obj=None, + saved_model_dir=None, + saved_model_tags=None): """Constructor for TFLiteConverter. Args: @@ -398,10 +445,19 @@ class TFLiteConverterV2(TFLiteConverterBase): get garbage collected since functions have a weak reference to Variables. This is only required when the tf.AutoTrackable object is not maintained by the user (e.g. `from_saved_model`). + saved_model_dir: Directory of the SavedModel. This argument can be null + when it creates via the from_keras_model and from_concrete_function + methods. + saved_model_tags: Set of tags identifying the MetaGraphDef within the + SavedModel to analyze. All tags in the tag set must be present. (default + set(SERVING)). This argument will be available when the saved model dir + argument is set. """ super(TFLiteConverterV2, self).__init__() self._funcs = funcs self._trackable_obj = trackable_obj + self._saved_model_dir = saved_model_dir + self._saved_model_tags = saved_model_tags @classmethod def from_concrete_functions(cls, funcs): @@ -463,6 +519,9 @@ class TFLiteConverterV2(TFLiteConverterBase): # Ensures any graphs created in Eager mode are able to run. This is required # in order to create a tf.estimator.Exporter that exports a TFLite model. + if tags is None: + tags = set([_tag_constants.SERVING]) + with context.eager_mode(): saved_model = _load(saved_model_dir, tags) if not signature_keys: @@ -475,7 +534,7 @@ class TFLiteConverterV2(TFLiteConverterBase): "'{}'.".format(key, ",".join(saved_model.signatures))) funcs.append(saved_model.signatures[key]) - return cls(funcs, saved_model) + return cls(funcs, saved_model, saved_model_dir, tags) @classmethod def from_keras_model(cls, model): @@ -521,6 +580,9 @@ class TFLiteConverterV2(TFLiteConverterBase): "ConcreteFunction. Converting multiple functions is " "under development.") + # Parses SavedModel argument. + self._parse_saved_model_args() + # graph_def is used here to preserve the node bug information frozen_func, graph_def = ( _convert_to_constants.convert_variables_to_constants_v2_as_graph( @@ -693,6 +755,7 @@ class TFLiteConverter(TFLiteConverterBase): the dataset to evaluate different optimizations. experimental_new_converter: Experimental flag, subject to change. Enables MLIR-based conversion instead of TOCO conversion. + Example usage: ```python @@ -725,7 +788,9 @@ class TFLiteConverter(TFLiteConverterBase): output_tensors, input_arrays_with_shape=None, output_arrays=None, - experimental_debug_info_func=None): + experimental_debug_info_func=None, + saved_model_dir=None, + saved_model_tags=None): """Constructor for TFLiteConverter. Args: @@ -743,6 +808,13 @@ class TFLiteConverter(TFLiteConverterBase): `output_tensors` are None. (default None) experimental_debug_info_func: An experimental function to retrieve the graph debug info for a set of nodes from the `graph_def`. + saved_model_dir: Directory of the SavedModel. This argument can be null + when it creates via the from_keras_model and from_concrete_function + methods. + saved_model_tags: Set of tags identifying the MetaGraphDef within the + SavedModel to analyze. All tags in the tag set must be present. (default + set(SERVING)). This argument will be available when the saved model dir + argument is set. Raises: ValueError: Invalid arguments. @@ -766,6 +838,8 @@ class TFLiteConverter(TFLiteConverterBase): self.conversion_summary_dir = None self._debug_info_func = experimental_debug_info_func self._custom_opdefs = None + self._saved_model_dir = saved_model_dir + self._saved_model_tags = saved_model_tags # Attributes are used by models that cannot be loaded into TensorFlow. if not self._has_valid_tensors(): @@ -928,7 +1002,9 @@ class TFLiteConverter(TFLiteConverterBase): graph_def=result[0], input_tensors=result[1], output_tensors=result[2], - experimental_debug_info_func=_build_debug_info_func(result[3])) + experimental_debug_info_func=_build_debug_info_func(result[3]), + saved_model_dir=saved_model_dir, + saved_model_tags=tag_set) @classmethod def from_keras_model_file(cls, @@ -1059,6 +1135,9 @@ class TFLiteConverter(TFLiteConverterBase): Input shape is not specified. None value for dimension in input_tensor. """ + # Parses SavedModel argument. + self._parse_saved_model_args() + quant_mode = QuantizationMode(self.optimizations, self.target_spec, self.representative_dataset, self._graph_def) diff --git a/tensorflow/lite/toco/model_flags.proto b/tensorflow/lite/toco/model_flags.proto index dfc425073f5..7fd42e4afd8 100644 --- a/tensorflow/lite/toco/model_flags.proto +++ b/tensorflow/lite/toco/model_flags.proto @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto2"; -import "tensorflow/lite/toco/types.proto"; package toco; +import "tensorflow/lite/toco/types.proto"; + message InputArrayShape { repeated int32 dims = 2; } @@ -130,7 +131,7 @@ message ArraysExtraInfo { // optional int32 input_dims = 11 [ default = 4]; // repeated int32 input_shape = 13; // -// Next ID to USE: 20. +// Next ID to USE: 24. message ModelFlags { // Information about the input arrays, i.e. the arrays from which input // activations will be read. @@ -181,4 +182,22 @@ message ModelFlags { // When set to false, toco will not change the input ranges and the output // ranges of concat operator to the overlap of all input ranges. optional bool change_concat_input_ranges = 19 [default = true]; + + // Filepath of the saved model to be converted. This value will be non-empty + // only when the saved model import path will be used. Otherwise, the graph + // def-based conversion will be processed. + optional string saved_model_dir = 20; + + // SavedModel file format version of The saved model file to be converted. + // This value will be set only when the SavedModel import path will be used. + optional int32 saved_model_version = 21; + + // Set of string saved model tags, formatted in the comma-separated value. + // This value will be set only when the SavedModel import path will be used. + repeated string saved_model_tags = 22; + + // Names to be exported (default: export all) when the saved model import path + // is on. This value will be set only when the SavedModel import path will be + // used. + repeated string saved_model_exported_names = 23; } diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index b8a00b90a06..236913c9678 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -49,6 +49,7 @@ cc_library( "//tensorflow/lite/toco:tooling_util", "//tensorflow/core:protos_all_cc", "//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer", + "//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer", ] + select({ # This is required when running `tflite_convert` from `bazel`. # It requires to link with TensorFlow Ops to get the op definitions. diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index 31de4cfc726..667754e956f 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -21,6 +21,7 @@ limitations under the License. #include "google/protobuf/text_format.h" #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h" +#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/toco/import_tensorflow.h" @@ -144,13 +145,6 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, } } - tensorflow::GraphDef graph_def; - if (!graph_def.ParseFromString(input_contents_txt)) { - PyErr_SetString(PyExc_ValueError, - "Failed to convert GraphDef to Python String."); - return nullptr; - } - auto& dump_options = *GraphVizDumpOptions::singleton(); if (toco_flags.has_dump_graphviz_dir()) { dump_options.dump_graphviz = toco_flags.dump_graphviz_dir(); @@ -165,13 +159,25 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, // Convert model. if (enable_mlir_converter) { - status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer( - model_flags, toco_flags, debug_info, graph_def, - &output_file_contents_txt); - if (!toco_flags.conversion_summary_dir().empty()) { - PopulateConversionLogHelper(model_flags, &toco_flags, input_contents_txt, - output_file_contents_txt, - status.error_message(), &dump_options); + if (!model_flags.saved_model_dir().empty()) { + status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer( + model_flags, toco_flags, &output_file_contents_txt); + } else { + tensorflow::GraphDef graph_def; + if (!graph_def.ParseFromString(input_contents_txt)) { + PyErr_SetString(PyExc_ValueError, + "Failed to convert GraphDef to Python String."); + return nullptr; + } + + status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer( + model_flags, toco_flags, debug_info, graph_def, + &output_file_contents_txt); + if (!toco_flags.conversion_summary_dir().empty()) { + PopulateConversionLogHelper( + model_flags, &toco_flags, input_contents_txt, + output_file_contents_txt, status.error_message(), &dump_options); + } } } else { status = Convert(input_contents_txt, toco_flags, model_flags, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt index db76bb3f4b3..0c43fc556aa 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "<type \'object\'>" member_method { name: "__init__" - argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "convert" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt index 63a6667c0b2..c575283b74d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "<type \'object\'>" member_method { name: "__init__" - argspec: "args=[\'self\', \'funcs\', \'trackable_obj\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'funcs\', \'trackable_obj\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "convert"