Enable Keras/RNN case via MLIR SavedModel import in TFLiteConverterV2
PiperOrigin-RevId: 304694033 Change-Id: I3c2586b92e1b4a810036ed390cb5b4d83352bef8
This commit is contained in:
parent
9a85d6fe42
commit
f21e640f0e
@ -771,6 +771,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/lite/tools/optimize:quantize_weights",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
@ -84,8 +84,14 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
|
||||
|
||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
pass_config.shape_inference = false;
|
||||
|
||||
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
||||
quant_specs, result);
|
||||
pass_config, result);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -43,8 +43,6 @@ namespace tensorflow {
|
||||
|
||||
Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
const string& saved_model_dir, bool saved_model_v1,
|
||||
const string& saved_model_tags, const string& saved_model_exported_names,
|
||||
string* result) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::TFL::QuantizationSpecs quant_specs;
|
||||
@ -66,13 +64,28 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
// Register all custom ops, including user-specified custom ops.
|
||||
TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
|
||||
|
||||
const bool import_saved_model = !saved_model_v1;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module,
|
||||
ImportSavedModel(import_saved_model, saved_model_v1, saved_model_dir,
|
||||
saved_model_tags, saved_model_exported_names, &context));
|
||||
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
||||
quant_specs, result);
|
||||
auto& saved_model_tags = model_flags.saved_model_tags();
|
||||
auto& saved_model_exported_names = model_flags.saved_model_exported_names();
|
||||
std::unordered_set<std::string> tags(saved_model_tags.begin(),
|
||||
saved_model_tags.end());
|
||||
auto exported_names_in_vector = std::vector<std::string>(
|
||||
saved_model_exported_names.begin(), saved_model_exported_names.end());
|
||||
absl::Span<std::string> exported_names(exported_names_in_vector);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
ImportSavedModel(model_flags.saved_model_dir(),
|
||||
model_flags.saved_model_version(), tags,
|
||||
exported_names, &context));
|
||||
|
||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
pass_config.shape_inference = true;
|
||||
|
||||
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
|
||||
toco_flags, std::move(module), pass_config, result);
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -28,8 +28,6 @@ namespace tensorflow {
|
||||
// status if it fails to convert the input.
|
||||
Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
const string& saved_model_dir, bool saved_model_v1,
|
||||
const string& saved_model_tags, const string& saved_model_exported_names,
|
||||
string* result);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -261,7 +261,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
mlir::OwningModuleRef module,
|
||||
mlir::TFL::QuantizationSpecs quant_specs,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
string* result) {
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
||||
@ -275,9 +275,6 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
}
|
||||
|
||||
mlir::PassManager pm(module->getContext());
|
||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
|
||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
||||
// Convert back to outlined while format for export back to flatbuffer.
|
||||
@ -288,7 +285,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
|
||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
|
||||
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result,
|
||||
&pm);
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||
// rename once we enable the new converter feature flag.
|
||||
|
@ -47,7 +47,7 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
||||
// This will also run relevant passes as well.
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
mlir::OwningModuleRef module,
|
||||
mlir::TFL::QuantizationSpecs quant_specs,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
string* result);
|
||||
|
||||
// Give a warning for any unused flags that have been specified.
|
||||
|
@ -138,13 +138,24 @@ int main(int argc, char **argv) {
|
||||
// TODO(b/147435528): We need to test the e2e behavior once the graph freezing
|
||||
// inside mlir is done.
|
||||
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
|
||||
int saved_model_version;
|
||||
if (import_saved_model_object_graph) {
|
||||
saved_model_version = 2;
|
||||
} else {
|
||||
saved_model_version = 1;
|
||||
}
|
||||
if (input_mlir)
|
||||
module = tensorflow::errors::InvalidArgument(
|
||||
"Importing saved model should not have input_mlir set");
|
||||
module = tensorflow::ImportSavedModel(import_saved_model_object_graph,
|
||||
import_saved_model_signature_defs,
|
||||
input_file_name, saved_model_tags,
|
||||
saved_model_exported_names, &context);
|
||||
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names_vector =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
absl::Span<std::string> exported_names(exported_names_vector);
|
||||
|
||||
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
|
||||
tags, exported_names, &context);
|
||||
} else {
|
||||
module = tensorflow::LoadFromGraphdefOrMlirSource(
|
||||
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
|
||||
|
@ -160,25 +160,17 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
bool import_saved_model, bool import_saved_model_v1,
|
||||
const std::string& input_filename, const std::string& saved_model_tags,
|
||||
const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
|
||||
if (import_saved_model) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
|
||||
const std::string& input_filename, const int saved_model_version,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
|
||||
if (saved_model_version == 2) {
|
||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
input_filename, tags, absl::Span<std::string>(exported_names), context);
|
||||
input_filename, tags, exported_names, context);
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
|
||||
return module;
|
||||
} else if (import_saved_model_v1) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
|
||||
} else if (saved_model_version == 1) {
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, context);
|
||||
|
||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
@ -42,9 +45,9 @@ LoadFromGraphdefOrMlirSource(
|
||||
|
||||
// Load Saved model (either v1 or v2) into MLIR.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
bool import_saved_model, bool import_saved_model_v1,
|
||||
const std::string& input_filename, const std::string& saved_model_tags,
|
||||
const std::string& saved_model_exported_names, mlir::MLIRContext* context);
|
||||
const std::string& input_filename, const int saved_model_version,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
|
||||
|
||||
// Taking a MLIR module in TF executor dialect and a set of parameters,
|
||||
// applies a set of passes to convert the module to TF Lite dialect and
|
||||
|
@ -257,7 +257,11 @@ def build_toco_convert_protos(input_tensors,
|
||||
target_ops=None,
|
||||
allow_nonexistent_arrays=False,
|
||||
debug_info=None,
|
||||
conversion_summary_dir=None):
|
||||
conversion_summary_dir=None,
|
||||
saved_model_dir=None,
|
||||
saved_model_version=0,
|
||||
saved_model_tags=None,
|
||||
saved_model_exported_names=None):
|
||||
"""Builds protocol buffers describing a conversion of a model using TOCO.
|
||||
|
||||
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
|
||||
@ -323,6 +327,18 @@ def build_toco_convert_protos(input_tensors,
|
||||
debug_info: `GraphDebugInfo` proto containing the stack traces for the
|
||||
original nodes referred by the converted graph.
|
||||
conversion_summary_dir: A string, the path to the generated conversion logs.
|
||||
saved_model_dir: Filepath of the saved model to be converted. This value
|
||||
will be non-empty only when the saved model import path will be used.
|
||||
Otherwises, the graph def-based conversion will be processed.
|
||||
saved_model_version: SavedModel file format version of The saved model file
|
||||
to be converted. This value will be set only when the SavedModel import
|
||||
path will be used.
|
||||
saved_model_tags: Set of string saved model tags, formatted in the
|
||||
comma-separated value. This value will be set only when the SavedModel
|
||||
import path will be used.
|
||||
saved_model_exported_names: Names to be exported (default: export all) when
|
||||
the saved model import path is on. This value will be set only when the
|
||||
SavedModel import path will be used.
|
||||
|
||||
Returns:
|
||||
model_flags, toco_flags, debug_info: three protocol buffers describing the
|
||||
@ -397,6 +413,14 @@ def build_toco_convert_protos(input_tensors,
|
||||
|
||||
model.allow_nonexistent_arrays = allow_nonexistent_arrays
|
||||
|
||||
if saved_model_dir:
|
||||
model.saved_model_dir = saved_model_dir
|
||||
model.saved_model_version = saved_model_version
|
||||
if saved_model_tags:
|
||||
model.saved_model_tags.extend(saved_model_tags)
|
||||
if saved_model_exported_names:
|
||||
model.saved_model_exported_names.extend(saved_model_exported_names)
|
||||
|
||||
return model, toco, debug_info
|
||||
|
||||
|
||||
|
@ -74,6 +74,7 @@ from tensorflow.python.lib.io import file_io as _file_io
|
||||
from tensorflow.python.saved_model import signature_constants as _signature_constants
|
||||
from tensorflow.python.saved_model import tag_constants as _tag_constants
|
||||
from tensorflow.python.saved_model.load import load as _load
|
||||
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
|
||||
from tensorflow.python.util import deprecation as _deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
||||
|
||||
@ -285,6 +286,10 @@ class TFLiteConverterBase(object):
|
||||
# The 'GraphDebugInfo' contains the stack traces of all the original nodes
|
||||
# in the `GraphDef` to the converter.
|
||||
self._debug_info = None
|
||||
self._saved_model_dir = None
|
||||
self._saved_model_tags = None
|
||||
self._saved_model_version = None
|
||||
self._saved_model_exported_names = []
|
||||
|
||||
def _grappler_config(self, optimizers=None):
|
||||
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler.
|
||||
@ -346,8 +351,46 @@ class TFLiteConverterBase(object):
|
||||
"target_ops": self.target_spec.supported_ops,
|
||||
"enable_mlir_converter": self.experimental_new_converter,
|
||||
}
|
||||
|
||||
if self._saved_model_dir:
|
||||
args.update({
|
||||
"saved_model_dir": self._saved_model_dir,
|
||||
"saved_model_version": self._saved_model_version,
|
||||
"saved_model_tags": self._saved_model_tags,
|
||||
"saved_model_exported_names": self._saved_model_exported_names,
|
||||
})
|
||||
|
||||
return args
|
||||
|
||||
def _contains_function_with_implements_attr(self, saved_model_proto):
|
||||
meta_graph = saved_model_proto.meta_graphs[0]
|
||||
for function in meta_graph.graph_def.library.function:
|
||||
if function.attr.get("_implements", None) or function.attr.get(
|
||||
"api_implements", None):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _parse_saved_model_args(self):
|
||||
"""Parses SavedModel arguments from the given Keras/RNN SavedModel."""
|
||||
if self._saved_model_dir:
|
||||
try:
|
||||
saved_model_proto, _ = (
|
||||
_parse_saved_model_with_debug_info(self._saved_model_dir))
|
||||
except OSError:
|
||||
# If it fails to read the given saved model, it will fall back to the
|
||||
# frozen graph def path.
|
||||
self._saved_model_dir = None
|
||||
return
|
||||
if not self._contains_function_with_implements_attr(saved_model_proto):
|
||||
self._saved_model_dir = None
|
||||
else:
|
||||
self._saved_model_exported_names = []
|
||||
self._saved_model_version = saved_model_proto.saved_model_schema_version
|
||||
if self._saved_model_version not in [1, 2]:
|
||||
raise ValueError(
|
||||
"SavedModel file format({0}) is not supported".format(
|
||||
self._saved_model_version))
|
||||
|
||||
|
||||
@_tf_export("lite.TFLiteConverter", v1=[])
|
||||
class TFLiteConverterV2(TFLiteConverterBase):
|
||||
@ -387,7 +430,11 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, funcs, trackable_obj=None):
|
||||
def __init__(self,
|
||||
funcs,
|
||||
trackable_obj=None,
|
||||
saved_model_dir=None,
|
||||
saved_model_tags=None):
|
||||
"""Constructor for TFLiteConverter.
|
||||
|
||||
Args:
|
||||
@ -398,10 +445,19 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
get garbage collected since functions have a weak reference to
|
||||
Variables. This is only required when the tf.AutoTrackable object is not
|
||||
maintained by the user (e.g. `from_saved_model`).
|
||||
saved_model_dir: Directory of the SavedModel. This argument can be null
|
||||
when it creates via the from_keras_model and from_concrete_function
|
||||
methods.
|
||||
saved_model_tags: Set of tags identifying the MetaGraphDef within the
|
||||
SavedModel to analyze. All tags in the tag set must be present. (default
|
||||
set(SERVING)). This argument will be available when the saved model dir
|
||||
argument is set.
|
||||
"""
|
||||
super(TFLiteConverterV2, self).__init__()
|
||||
self._funcs = funcs
|
||||
self._trackable_obj = trackable_obj
|
||||
self._saved_model_dir = saved_model_dir
|
||||
self._saved_model_tags = saved_model_tags
|
||||
|
||||
@classmethod
|
||||
def from_concrete_functions(cls, funcs):
|
||||
@ -463,6 +519,9 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
|
||||
# Ensures any graphs created in Eager mode are able to run. This is required
|
||||
# in order to create a tf.estimator.Exporter that exports a TFLite model.
|
||||
if tags is None:
|
||||
tags = set([_tag_constants.SERVING])
|
||||
|
||||
with context.eager_mode():
|
||||
saved_model = _load(saved_model_dir, tags)
|
||||
if not signature_keys:
|
||||
@ -475,7 +534,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
"'{}'.".format(key, ",".join(saved_model.signatures)))
|
||||
funcs.append(saved_model.signatures[key])
|
||||
|
||||
return cls(funcs, saved_model)
|
||||
return cls(funcs, saved_model, saved_model_dir, tags)
|
||||
|
||||
@classmethod
|
||||
def from_keras_model(cls, model):
|
||||
@ -521,6 +580,9 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
||||
"ConcreteFunction. Converting multiple functions is "
|
||||
"under development.")
|
||||
|
||||
# Parses SavedModel argument.
|
||||
self._parse_saved_model_args()
|
||||
|
||||
# graph_def is used here to preserve the node bug information
|
||||
frozen_func, graph_def = (
|
||||
_convert_to_constants.convert_variables_to_constants_v2_as_graph(
|
||||
@ -693,6 +755,7 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
the dataset to evaluate different optimizations.
|
||||
experimental_new_converter: Experimental flag, subject to change.
|
||||
Enables MLIR-based conversion instead of TOCO conversion.
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
@ -725,7 +788,9 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
output_tensors,
|
||||
input_arrays_with_shape=None,
|
||||
output_arrays=None,
|
||||
experimental_debug_info_func=None):
|
||||
experimental_debug_info_func=None,
|
||||
saved_model_dir=None,
|
||||
saved_model_tags=None):
|
||||
"""Constructor for TFLiteConverter.
|
||||
|
||||
Args:
|
||||
@ -743,6 +808,13 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
`output_tensors` are None. (default None)
|
||||
experimental_debug_info_func: An experimental function to retrieve the
|
||||
graph debug info for a set of nodes from the `graph_def`.
|
||||
saved_model_dir: Directory of the SavedModel. This argument can be null
|
||||
when it creates via the from_keras_model and from_concrete_function
|
||||
methods.
|
||||
saved_model_tags: Set of tags identifying the MetaGraphDef within the
|
||||
SavedModel to analyze. All tags in the tag set must be present. (default
|
||||
set(SERVING)). This argument will be available when the saved model dir
|
||||
argument is set.
|
||||
|
||||
Raises:
|
||||
ValueError: Invalid arguments.
|
||||
@ -766,6 +838,8 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
self.conversion_summary_dir = None
|
||||
self._debug_info_func = experimental_debug_info_func
|
||||
self._custom_opdefs = None
|
||||
self._saved_model_dir = saved_model_dir
|
||||
self._saved_model_tags = saved_model_tags
|
||||
|
||||
# Attributes are used by models that cannot be loaded into TensorFlow.
|
||||
if not self._has_valid_tensors():
|
||||
@ -928,7 +1002,9 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
graph_def=result[0],
|
||||
input_tensors=result[1],
|
||||
output_tensors=result[2],
|
||||
experimental_debug_info_func=_build_debug_info_func(result[3]))
|
||||
experimental_debug_info_func=_build_debug_info_func(result[3]),
|
||||
saved_model_dir=saved_model_dir,
|
||||
saved_model_tags=tag_set)
|
||||
|
||||
@classmethod
|
||||
def from_keras_model_file(cls,
|
||||
@ -1059,6 +1135,9 @@ class TFLiteConverter(TFLiteConverterBase):
|
||||
Input shape is not specified.
|
||||
None value for dimension in input_tensor.
|
||||
"""
|
||||
# Parses SavedModel argument.
|
||||
self._parse_saved_model_args()
|
||||
|
||||
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
|
||||
self.representative_dataset, self._graph_def)
|
||||
|
||||
|
@ -12,10 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
syntax = "proto2";
|
||||
import "tensorflow/lite/toco/types.proto";
|
||||
|
||||
package toco;
|
||||
|
||||
import "tensorflow/lite/toco/types.proto";
|
||||
|
||||
message InputArrayShape {
|
||||
repeated int32 dims = 2;
|
||||
}
|
||||
@ -130,7 +131,7 @@ message ArraysExtraInfo {
|
||||
// optional int32 input_dims = 11 [ default = 4];
|
||||
// repeated int32 input_shape = 13;
|
||||
//
|
||||
// Next ID to USE: 20.
|
||||
// Next ID to USE: 24.
|
||||
message ModelFlags {
|
||||
// Information about the input arrays, i.e. the arrays from which input
|
||||
// activations will be read.
|
||||
@ -181,4 +182,22 @@ message ModelFlags {
|
||||
// When set to false, toco will not change the input ranges and the output
|
||||
// ranges of concat operator to the overlap of all input ranges.
|
||||
optional bool change_concat_input_ranges = 19 [default = true];
|
||||
|
||||
// Filepath of the saved model to be converted. This value will be non-empty
|
||||
// only when the saved model import path will be used. Otherwise, the graph
|
||||
// def-based conversion will be processed.
|
||||
optional string saved_model_dir = 20;
|
||||
|
||||
// SavedModel file format version of The saved model file to be converted.
|
||||
// This value will be set only when the SavedModel import path will be used.
|
||||
optional int32 saved_model_version = 21;
|
||||
|
||||
// Set of string saved model tags, formatted in the comma-separated value.
|
||||
// This value will be set only when the SavedModel import path will be used.
|
||||
repeated string saved_model_tags = 22;
|
||||
|
||||
// Names to be exported (default: export all) when the saved model import path
|
||||
// is on. This value will be set only when the SavedModel import path will be
|
||||
// used.
|
||||
repeated string saved_model_exported_names = 23;
|
||||
}
|
||||
|
@ -49,6 +49,7 @@ cc_library(
|
||||
"//tensorflow/lite/toco:tooling_util",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer",
|
||||
"//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer",
|
||||
] + select({
|
||||
# This is required when running `tflite_convert` from `bazel`.
|
||||
# It requires to link with TensorFlow Ops to get the op definitions.
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "google/protobuf/text_format.h"
|
||||
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
|
||||
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
#include "tensorflow/lite/toco/import_tensorflow.h"
|
||||
@ -144,13 +145,6 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::GraphDef graph_def;
|
||||
if (!graph_def.ParseFromString(input_contents_txt)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Failed to convert GraphDef to Python String.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto& dump_options = *GraphVizDumpOptions::singleton();
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
|
||||
@ -165,13 +159,25 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
|
||||
// Convert model.
|
||||
if (enable_mlir_converter) {
|
||||
status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
|
||||
model_flags, toco_flags, debug_info, graph_def,
|
||||
&output_file_contents_txt);
|
||||
if (!toco_flags.conversion_summary_dir().empty()) {
|
||||
PopulateConversionLogHelper(model_flags, &toco_flags, input_contents_txt,
|
||||
output_file_contents_txt,
|
||||
status.error_message(), &dump_options);
|
||||
if (!model_flags.saved_model_dir().empty()) {
|
||||
status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
|
||||
model_flags, toco_flags, &output_file_contents_txt);
|
||||
} else {
|
||||
tensorflow::GraphDef graph_def;
|
||||
if (!graph_def.ParseFromString(input_contents_txt)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Failed to convert GraphDef to Python String.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
|
||||
model_flags, toco_flags, debug_info, graph_def,
|
||||
&output_file_contents_txt);
|
||||
if (!toco_flags.conversion_summary_dir().empty()) {
|
||||
PopulateConversionLogHelper(
|
||||
model_flags, &toco_flags, input_contents_txt,
|
||||
output_file_contents_txt, status.error_message(), &dump_options);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
status = Convert(input_contents_txt, toco_flags, model_flags,
|
||||
|
@ -5,7 +5,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "convert"
|
||||
|
@ -5,7 +5,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'funcs\', \'trackable_obj\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'funcs\', \'trackable_obj\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "convert"
|
||||
|
Loading…
Reference in New Issue
Block a user