Enable Keras/RNN case via MLIR SavedModel import in TFLiteConverterV2
PiperOrigin-RevId: 304694033 Change-Id: I3c2586b92e1b4a810036ed390cb5b4d83352bef8
This commit is contained in:
parent
9a85d6fe42
commit
f21e640f0e
tensorflow
@ -771,6 +771,7 @@ cc_library(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/lite/tools/optimize:quantize_weights",
|
"//tensorflow/lite/tools/optimize:quantize_weights",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
"@llvm-project//mlir:AllPassesAndDialects",
|
"@llvm-project//mlir:AllPassesAndDialects",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
@ -84,8 +84,14 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
|||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
|
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
|
||||||
|
|
||||||
|
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||||
|
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||||
|
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||||
|
pass_config.lower_tensor_list_ops = true;
|
||||||
|
pass_config.shape_inference = false;
|
||||||
|
|
||||||
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
||||||
quant_specs, result);
|
pass_config, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -43,8 +43,6 @@ namespace tensorflow {
|
|||||||
|
|
||||||
Status ConvertSavedModelToTFLiteFlatBuffer(
|
Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||||
const string& saved_model_dir, bool saved_model_v1,
|
|
||||||
const string& saved_model_tags, const string& saved_model_exported_names,
|
|
||||||
string* result) {
|
string* result) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
mlir::TFL::QuantizationSpecs quant_specs;
|
mlir::TFL::QuantizationSpecs quant_specs;
|
||||||
@ -66,13 +64,28 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
|||||||
// Register all custom ops, including user-specified custom ops.
|
// Register all custom ops, including user-specified custom ops.
|
||||||
TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
|
TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
|
||||||
|
|
||||||
const bool import_saved_model = !saved_model_v1;
|
auto& saved_model_tags = model_flags.saved_model_tags();
|
||||||
TF_ASSIGN_OR_RETURN(
|
auto& saved_model_exported_names = model_flags.saved_model_exported_names();
|
||||||
auto module,
|
std::unordered_set<std::string> tags(saved_model_tags.begin(),
|
||||||
ImportSavedModel(import_saved_model, saved_model_v1, saved_model_dir,
|
saved_model_tags.end());
|
||||||
saved_model_tags, saved_model_exported_names, &context));
|
auto exported_names_in_vector = std::vector<std::string>(
|
||||||
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
saved_model_exported_names.begin(), saved_model_exported_names.end());
|
||||||
quant_specs, result);
|
absl::Span<std::string> exported_names(exported_names_in_vector);
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(auto module,
|
||||||
|
ImportSavedModel(model_flags.saved_model_dir(),
|
||||||
|
model_flags.saved_model_version(), tags,
|
||||||
|
exported_names, &context));
|
||||||
|
|
||||||
|
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||||
|
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||||
|
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||||
|
pass_config.lower_tensor_list_ops = true;
|
||||||
|
pass_config.shape_inference = true;
|
||||||
|
|
||||||
|
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
|
||||||
|
toco_flags, std::move(module), pass_config, result);
|
||||||
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -28,8 +28,6 @@ namespace tensorflow {
|
|||||||
// status if it fails to convert the input.
|
// status if it fails to convert the input.
|
||||||
Status ConvertSavedModelToTFLiteFlatBuffer(
|
Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||||
const string& saved_model_dir, bool saved_model_v1,
|
|
||||||
const string& saved_model_tags, const string& saved_model_exported_names,
|
|
||||||
string* result);
|
string* result);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -261,7 +261,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
|||||||
|
|
||||||
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||||
mlir::OwningModuleRef module,
|
mlir::OwningModuleRef module,
|
||||||
mlir::TFL::QuantizationSpecs quant_specs,
|
const mlir::TFL::PassConfig& pass_config,
|
||||||
string* result) {
|
string* result) {
|
||||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||||
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
||||||
@ -275,9 +275,6 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
|||||||
}
|
}
|
||||||
|
|
||||||
mlir::PassManager pm(module->getContext());
|
mlir::PassManager pm(module->getContext());
|
||||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
|
||||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
|
||||||
pass_config.lower_tensor_list_ops = true;
|
|
||||||
|
|
||||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
||||||
// Convert back to outlined while format for export back to flatbuffer.
|
// Convert back to outlined while format for export back to flatbuffer.
|
||||||
@ -288,7 +285,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
|||||||
|
|
||||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||||
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
|
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result,
|
||||||
|
&pm);
|
||||||
if (toco_flags.has_dump_graphviz_dir()) {
|
if (toco_flags.has_dump_graphviz_dir()) {
|
||||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||||
// rename once we enable the new converter feature flag.
|
// rename once we enable the new converter feature flag.
|
||||||
|
@ -47,7 +47,7 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
|
|||||||
// This will also run relevant passes as well.
|
// This will also run relevant passes as well.
|
||||||
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||||
mlir::OwningModuleRef module,
|
mlir::OwningModuleRef module,
|
||||||
mlir::TFL::QuantizationSpecs quant_specs,
|
const mlir::TFL::PassConfig& pass_config,
|
||||||
string* result);
|
string* result);
|
||||||
|
|
||||||
// Give a warning for any unused flags that have been specified.
|
// Give a warning for any unused flags that have been specified.
|
||||||
|
@ -138,13 +138,24 @@ int main(int argc, char **argv) {
|
|||||||
// TODO(b/147435528): We need to test the e2e behavior once the graph freezing
|
// TODO(b/147435528): We need to test the e2e behavior once the graph freezing
|
||||||
// inside mlir is done.
|
// inside mlir is done.
|
||||||
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
|
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
|
||||||
|
int saved_model_version;
|
||||||
|
if (import_saved_model_object_graph) {
|
||||||
|
saved_model_version = 2;
|
||||||
|
} else {
|
||||||
|
saved_model_version = 1;
|
||||||
|
}
|
||||||
if (input_mlir)
|
if (input_mlir)
|
||||||
module = tensorflow::errors::InvalidArgument(
|
module = tensorflow::errors::InvalidArgument(
|
||||||
"Importing saved model should not have input_mlir set");
|
"Importing saved model should not have input_mlir set");
|
||||||
module = tensorflow::ImportSavedModel(import_saved_model_object_graph,
|
|
||||||
import_saved_model_signature_defs,
|
std::unordered_set<std::string> tags =
|
||||||
input_file_name, saved_model_tags,
|
absl::StrSplit(saved_model_tags, ',');
|
||||||
saved_model_exported_names, &context);
|
std::vector<std::string> exported_names_vector =
|
||||||
|
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||||
|
absl::Span<std::string> exported_names(exported_names_vector);
|
||||||
|
|
||||||
|
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
|
||||||
|
tags, exported_names, &context);
|
||||||
} else {
|
} else {
|
||||||
module = tensorflow::LoadFromGraphdefOrMlirSource(
|
module = tensorflow::LoadFromGraphdefOrMlirSource(
|
||||||
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
|
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
|
||||||
|
@ -160,25 +160,17 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||||
bool import_saved_model, bool import_saved_model_v1,
|
const std::string& input_filename, const int saved_model_version,
|
||||||
const std::string& input_filename, const std::string& saved_model_tags,
|
const std::unordered_set<std::string>& tags,
|
||||||
const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
|
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
|
||||||
if (import_saved_model) {
|
if (saved_model_version == 2) {
|
||||||
std::unordered_set<std::string> tags =
|
|
||||||
absl::StrSplit(saved_model_tags, ',');
|
|
||||||
std::vector<std::string> exported_names =
|
|
||||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
|
||||||
|
|
||||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||||
input_filename, tags, absl::Span<std::string>(exported_names), context);
|
input_filename, tags, exported_names, context);
|
||||||
if (!module)
|
if (!module)
|
||||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||||
|
|
||||||
return module;
|
return module;
|
||||||
} else if (import_saved_model_v1) {
|
} else if (saved_model_version == 1) {
|
||||||
std::unordered_set<std::string> tags =
|
|
||||||
absl::StrSplit(saved_model_tags, ',');
|
|
||||||
|
|
||||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||||
input_filename, tags, context);
|
input_filename, tags, context);
|
||||||
|
|
||||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
|
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include "absl/types/span.h"
|
||||||
#include "llvm/Support/SourceMgr.h"
|
#include "llvm/Support/SourceMgr.h"
|
||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
@ -42,9 +45,9 @@ LoadFromGraphdefOrMlirSource(
|
|||||||
|
|
||||||
// Load Saved model (either v1 or v2) into MLIR.
|
// Load Saved model (either v1 or v2) into MLIR.
|
||||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||||
bool import_saved_model, bool import_saved_model_v1,
|
const std::string& input_filename, const int saved_model_version,
|
||||||
const std::string& input_filename, const std::string& saved_model_tags,
|
const std::unordered_set<std::string>& tags,
|
||||||
const std::string& saved_model_exported_names, mlir::MLIRContext* context);
|
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
|
||||||
|
|
||||||
// Taking a MLIR module in TF executor dialect and a set of parameters,
|
// Taking a MLIR module in TF executor dialect and a set of parameters,
|
||||||
// applies a set of passes to convert the module to TF Lite dialect and
|
// applies a set of passes to convert the module to TF Lite dialect and
|
||||||
|
@ -257,7 +257,11 @@ def build_toco_convert_protos(input_tensors,
|
|||||||
target_ops=None,
|
target_ops=None,
|
||||||
allow_nonexistent_arrays=False,
|
allow_nonexistent_arrays=False,
|
||||||
debug_info=None,
|
debug_info=None,
|
||||||
conversion_summary_dir=None):
|
conversion_summary_dir=None,
|
||||||
|
saved_model_dir=None,
|
||||||
|
saved_model_version=0,
|
||||||
|
saved_model_tags=None,
|
||||||
|
saved_model_exported_names=None):
|
||||||
"""Builds protocol buffers describing a conversion of a model using TOCO.
|
"""Builds protocol buffers describing a conversion of a model using TOCO.
|
||||||
|
|
||||||
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
|
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
|
||||||
@ -323,6 +327,18 @@ def build_toco_convert_protos(input_tensors,
|
|||||||
debug_info: `GraphDebugInfo` proto containing the stack traces for the
|
debug_info: `GraphDebugInfo` proto containing the stack traces for the
|
||||||
original nodes referred by the converted graph.
|
original nodes referred by the converted graph.
|
||||||
conversion_summary_dir: A string, the path to the generated conversion logs.
|
conversion_summary_dir: A string, the path to the generated conversion logs.
|
||||||
|
saved_model_dir: Filepath of the saved model to be converted. This value
|
||||||
|
will be non-empty only when the saved model import path will be used.
|
||||||
|
Otherwises, the graph def-based conversion will be processed.
|
||||||
|
saved_model_version: SavedModel file format version of The saved model file
|
||||||
|
to be converted. This value will be set only when the SavedModel import
|
||||||
|
path will be used.
|
||||||
|
saved_model_tags: Set of string saved model tags, formatted in the
|
||||||
|
comma-separated value. This value will be set only when the SavedModel
|
||||||
|
import path will be used.
|
||||||
|
saved_model_exported_names: Names to be exported (default: export all) when
|
||||||
|
the saved model import path is on. This value will be set only when the
|
||||||
|
SavedModel import path will be used.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
model_flags, toco_flags, debug_info: three protocol buffers describing the
|
model_flags, toco_flags, debug_info: three protocol buffers describing the
|
||||||
@ -397,6 +413,14 @@ def build_toco_convert_protos(input_tensors,
|
|||||||
|
|
||||||
model.allow_nonexistent_arrays = allow_nonexistent_arrays
|
model.allow_nonexistent_arrays = allow_nonexistent_arrays
|
||||||
|
|
||||||
|
if saved_model_dir:
|
||||||
|
model.saved_model_dir = saved_model_dir
|
||||||
|
model.saved_model_version = saved_model_version
|
||||||
|
if saved_model_tags:
|
||||||
|
model.saved_model_tags.extend(saved_model_tags)
|
||||||
|
if saved_model_exported_names:
|
||||||
|
model.saved_model_exported_names.extend(saved_model_exported_names)
|
||||||
|
|
||||||
return model, toco, debug_info
|
return model, toco, debug_info
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,6 +74,7 @@ from tensorflow.python.lib.io import file_io as _file_io
|
|||||||
from tensorflow.python.saved_model import signature_constants as _signature_constants
|
from tensorflow.python.saved_model import signature_constants as _signature_constants
|
||||||
from tensorflow.python.saved_model import tag_constants as _tag_constants
|
from tensorflow.python.saved_model import tag_constants as _tag_constants
|
||||||
from tensorflow.python.saved_model.load import load as _load
|
from tensorflow.python.saved_model.load import load as _load
|
||||||
|
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
|
||||||
from tensorflow.python.util import deprecation as _deprecation
|
from tensorflow.python.util import deprecation as _deprecation
|
||||||
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
||||||
|
|
||||||
@ -285,6 +286,10 @@ class TFLiteConverterBase(object):
|
|||||||
# The 'GraphDebugInfo' contains the stack traces of all the original nodes
|
# The 'GraphDebugInfo' contains the stack traces of all the original nodes
|
||||||
# in the `GraphDef` to the converter.
|
# in the `GraphDef` to the converter.
|
||||||
self._debug_info = None
|
self._debug_info = None
|
||||||
|
self._saved_model_dir = None
|
||||||
|
self._saved_model_tags = None
|
||||||
|
self._saved_model_version = None
|
||||||
|
self._saved_model_exported_names = []
|
||||||
|
|
||||||
def _grappler_config(self, optimizers=None):
|
def _grappler_config(self, optimizers=None):
|
||||||
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler.
|
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler.
|
||||||
@ -346,8 +351,46 @@ class TFLiteConverterBase(object):
|
|||||||
"target_ops": self.target_spec.supported_ops,
|
"target_ops": self.target_spec.supported_ops,
|
||||||
"enable_mlir_converter": self.experimental_new_converter,
|
"enable_mlir_converter": self.experimental_new_converter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self._saved_model_dir:
|
||||||
|
args.update({
|
||||||
|
"saved_model_dir": self._saved_model_dir,
|
||||||
|
"saved_model_version": self._saved_model_version,
|
||||||
|
"saved_model_tags": self._saved_model_tags,
|
||||||
|
"saved_model_exported_names": self._saved_model_exported_names,
|
||||||
|
})
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
def _contains_function_with_implements_attr(self, saved_model_proto):
|
||||||
|
meta_graph = saved_model_proto.meta_graphs[0]
|
||||||
|
for function in meta_graph.graph_def.library.function:
|
||||||
|
if function.attr.get("_implements", None) or function.attr.get(
|
||||||
|
"api_implements", None):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _parse_saved_model_args(self):
|
||||||
|
"""Parses SavedModel arguments from the given Keras/RNN SavedModel."""
|
||||||
|
if self._saved_model_dir:
|
||||||
|
try:
|
||||||
|
saved_model_proto, _ = (
|
||||||
|
_parse_saved_model_with_debug_info(self._saved_model_dir))
|
||||||
|
except OSError:
|
||||||
|
# If it fails to read the given saved model, it will fall back to the
|
||||||
|
# frozen graph def path.
|
||||||
|
self._saved_model_dir = None
|
||||||
|
return
|
||||||
|
if not self._contains_function_with_implements_attr(saved_model_proto):
|
||||||
|
self._saved_model_dir = None
|
||||||
|
else:
|
||||||
|
self._saved_model_exported_names = []
|
||||||
|
self._saved_model_version = saved_model_proto.saved_model_schema_version
|
||||||
|
if self._saved_model_version not in [1, 2]:
|
||||||
|
raise ValueError(
|
||||||
|
"SavedModel file format({0}) is not supported".format(
|
||||||
|
self._saved_model_version))
|
||||||
|
|
||||||
|
|
||||||
@_tf_export("lite.TFLiteConverter", v1=[])
|
@_tf_export("lite.TFLiteConverter", v1=[])
|
||||||
class TFLiteConverterV2(TFLiteConverterBase):
|
class TFLiteConverterV2(TFLiteConverterBase):
|
||||||
@ -387,7 +430,11 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, funcs, trackable_obj=None):
|
def __init__(self,
|
||||||
|
funcs,
|
||||||
|
trackable_obj=None,
|
||||||
|
saved_model_dir=None,
|
||||||
|
saved_model_tags=None):
|
||||||
"""Constructor for TFLiteConverter.
|
"""Constructor for TFLiteConverter.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -398,10 +445,19 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
get garbage collected since functions have a weak reference to
|
get garbage collected since functions have a weak reference to
|
||||||
Variables. This is only required when the tf.AutoTrackable object is not
|
Variables. This is only required when the tf.AutoTrackable object is not
|
||||||
maintained by the user (e.g. `from_saved_model`).
|
maintained by the user (e.g. `from_saved_model`).
|
||||||
|
saved_model_dir: Directory of the SavedModel. This argument can be null
|
||||||
|
when it creates via the from_keras_model and from_concrete_function
|
||||||
|
methods.
|
||||||
|
saved_model_tags: Set of tags identifying the MetaGraphDef within the
|
||||||
|
SavedModel to analyze. All tags in the tag set must be present. (default
|
||||||
|
set(SERVING)). This argument will be available when the saved model dir
|
||||||
|
argument is set.
|
||||||
"""
|
"""
|
||||||
super(TFLiteConverterV2, self).__init__()
|
super(TFLiteConverterV2, self).__init__()
|
||||||
self._funcs = funcs
|
self._funcs = funcs
|
||||||
self._trackable_obj = trackable_obj
|
self._trackable_obj = trackable_obj
|
||||||
|
self._saved_model_dir = saved_model_dir
|
||||||
|
self._saved_model_tags = saved_model_tags
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_concrete_functions(cls, funcs):
|
def from_concrete_functions(cls, funcs):
|
||||||
@ -463,6 +519,9 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
|
|
||||||
# Ensures any graphs created in Eager mode are able to run. This is required
|
# Ensures any graphs created in Eager mode are able to run. This is required
|
||||||
# in order to create a tf.estimator.Exporter that exports a TFLite model.
|
# in order to create a tf.estimator.Exporter that exports a TFLite model.
|
||||||
|
if tags is None:
|
||||||
|
tags = set([_tag_constants.SERVING])
|
||||||
|
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
saved_model = _load(saved_model_dir, tags)
|
saved_model = _load(saved_model_dir, tags)
|
||||||
if not signature_keys:
|
if not signature_keys:
|
||||||
@ -475,7 +534,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
"'{}'.".format(key, ",".join(saved_model.signatures)))
|
"'{}'.".format(key, ",".join(saved_model.signatures)))
|
||||||
funcs.append(saved_model.signatures[key])
|
funcs.append(saved_model.signatures[key])
|
||||||
|
|
||||||
return cls(funcs, saved_model)
|
return cls(funcs, saved_model, saved_model_dir, tags)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_keras_model(cls, model):
|
def from_keras_model(cls, model):
|
||||||
@ -521,6 +580,9 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
"ConcreteFunction. Converting multiple functions is "
|
"ConcreteFunction. Converting multiple functions is "
|
||||||
"under development.")
|
"under development.")
|
||||||
|
|
||||||
|
# Parses SavedModel argument.
|
||||||
|
self._parse_saved_model_args()
|
||||||
|
|
||||||
# graph_def is used here to preserve the node bug information
|
# graph_def is used here to preserve the node bug information
|
||||||
frozen_func, graph_def = (
|
frozen_func, graph_def = (
|
||||||
_convert_to_constants.convert_variables_to_constants_v2_as_graph(
|
_convert_to_constants.convert_variables_to_constants_v2_as_graph(
|
||||||
@ -693,6 +755,7 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
the dataset to evaluate different optimizations.
|
the dataset to evaluate different optimizations.
|
||||||
experimental_new_converter: Experimental flag, subject to change.
|
experimental_new_converter: Experimental flag, subject to change.
|
||||||
Enables MLIR-based conversion instead of TOCO conversion.
|
Enables MLIR-based conversion instead of TOCO conversion.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@ -725,7 +788,9 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
output_tensors,
|
output_tensors,
|
||||||
input_arrays_with_shape=None,
|
input_arrays_with_shape=None,
|
||||||
output_arrays=None,
|
output_arrays=None,
|
||||||
experimental_debug_info_func=None):
|
experimental_debug_info_func=None,
|
||||||
|
saved_model_dir=None,
|
||||||
|
saved_model_tags=None):
|
||||||
"""Constructor for TFLiteConverter.
|
"""Constructor for TFLiteConverter.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -743,6 +808,13 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
`output_tensors` are None. (default None)
|
`output_tensors` are None. (default None)
|
||||||
experimental_debug_info_func: An experimental function to retrieve the
|
experimental_debug_info_func: An experimental function to retrieve the
|
||||||
graph debug info for a set of nodes from the `graph_def`.
|
graph debug info for a set of nodes from the `graph_def`.
|
||||||
|
saved_model_dir: Directory of the SavedModel. This argument can be null
|
||||||
|
when it creates via the from_keras_model and from_concrete_function
|
||||||
|
methods.
|
||||||
|
saved_model_tags: Set of tags identifying the MetaGraphDef within the
|
||||||
|
SavedModel to analyze. All tags in the tag set must be present. (default
|
||||||
|
set(SERVING)). This argument will be available when the saved model dir
|
||||||
|
argument is set.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: Invalid arguments.
|
ValueError: Invalid arguments.
|
||||||
@ -766,6 +838,8 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
self.conversion_summary_dir = None
|
self.conversion_summary_dir = None
|
||||||
self._debug_info_func = experimental_debug_info_func
|
self._debug_info_func = experimental_debug_info_func
|
||||||
self._custom_opdefs = None
|
self._custom_opdefs = None
|
||||||
|
self._saved_model_dir = saved_model_dir
|
||||||
|
self._saved_model_tags = saved_model_tags
|
||||||
|
|
||||||
# Attributes are used by models that cannot be loaded into TensorFlow.
|
# Attributes are used by models that cannot be loaded into TensorFlow.
|
||||||
if not self._has_valid_tensors():
|
if not self._has_valid_tensors():
|
||||||
@ -928,7 +1002,9 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
graph_def=result[0],
|
graph_def=result[0],
|
||||||
input_tensors=result[1],
|
input_tensors=result[1],
|
||||||
output_tensors=result[2],
|
output_tensors=result[2],
|
||||||
experimental_debug_info_func=_build_debug_info_func(result[3]))
|
experimental_debug_info_func=_build_debug_info_func(result[3]),
|
||||||
|
saved_model_dir=saved_model_dir,
|
||||||
|
saved_model_tags=tag_set)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_keras_model_file(cls,
|
def from_keras_model_file(cls,
|
||||||
@ -1059,6 +1135,9 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
Input shape is not specified.
|
Input shape is not specified.
|
||||||
None value for dimension in input_tensor.
|
None value for dimension in input_tensor.
|
||||||
"""
|
"""
|
||||||
|
# Parses SavedModel argument.
|
||||||
|
self._parse_saved_model_args()
|
||||||
|
|
||||||
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
|
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
|
||||||
self.representative_dataset, self._graph_def)
|
self.representative_dataset, self._graph_def)
|
||||||
|
|
||||||
|
@ -12,10 +12,11 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
import "tensorflow/lite/toco/types.proto";
|
|
||||||
|
|
||||||
package toco;
|
package toco;
|
||||||
|
|
||||||
|
import "tensorflow/lite/toco/types.proto";
|
||||||
|
|
||||||
message InputArrayShape {
|
message InputArrayShape {
|
||||||
repeated int32 dims = 2;
|
repeated int32 dims = 2;
|
||||||
}
|
}
|
||||||
@ -130,7 +131,7 @@ message ArraysExtraInfo {
|
|||||||
// optional int32 input_dims = 11 [ default = 4];
|
// optional int32 input_dims = 11 [ default = 4];
|
||||||
// repeated int32 input_shape = 13;
|
// repeated int32 input_shape = 13;
|
||||||
//
|
//
|
||||||
// Next ID to USE: 20.
|
// Next ID to USE: 24.
|
||||||
message ModelFlags {
|
message ModelFlags {
|
||||||
// Information about the input arrays, i.e. the arrays from which input
|
// Information about the input arrays, i.e. the arrays from which input
|
||||||
// activations will be read.
|
// activations will be read.
|
||||||
@ -181,4 +182,22 @@ message ModelFlags {
|
|||||||
// When set to false, toco will not change the input ranges and the output
|
// When set to false, toco will not change the input ranges and the output
|
||||||
// ranges of concat operator to the overlap of all input ranges.
|
// ranges of concat operator to the overlap of all input ranges.
|
||||||
optional bool change_concat_input_ranges = 19 [default = true];
|
optional bool change_concat_input_ranges = 19 [default = true];
|
||||||
|
|
||||||
|
// Filepath of the saved model to be converted. This value will be non-empty
|
||||||
|
// only when the saved model import path will be used. Otherwise, the graph
|
||||||
|
// def-based conversion will be processed.
|
||||||
|
optional string saved_model_dir = 20;
|
||||||
|
|
||||||
|
// SavedModel file format version of The saved model file to be converted.
|
||||||
|
// This value will be set only when the SavedModel import path will be used.
|
||||||
|
optional int32 saved_model_version = 21;
|
||||||
|
|
||||||
|
// Set of string saved model tags, formatted in the comma-separated value.
|
||||||
|
// This value will be set only when the SavedModel import path will be used.
|
||||||
|
repeated string saved_model_tags = 22;
|
||||||
|
|
||||||
|
// Names to be exported (default: export all) when the saved model import path
|
||||||
|
// is on. This value will be set only when the SavedModel import path will be
|
||||||
|
// used.
|
||||||
|
repeated string saved_model_exported_names = 23;
|
||||||
}
|
}
|
||||||
|
@ -49,6 +49,7 @@ cc_library(
|
|||||||
"//tensorflow/lite/toco:tooling_util",
|
"//tensorflow/lite/toco:tooling_util",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer",
|
"//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer",
|
||||||
|
"//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer",
|
||||||
] + select({
|
] + select({
|
||||||
# This is required when running `tflite_convert` from `bazel`.
|
# This is required when running `tflite_convert` from `bazel`.
|
||||||
# It requires to link with TensorFlow Ops to get the op definitions.
|
# It requires to link with TensorFlow Ops to get the op definitions.
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "google/protobuf/text_format.h"
|
#include "google/protobuf/text_format.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
|
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
|
||||||
|
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||||
#include "tensorflow/lite/toco/import_tensorflow.h"
|
#include "tensorflow/lite/toco/import_tensorflow.h"
|
||||||
@ -144,13 +145,6 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::GraphDef graph_def;
|
|
||||||
if (!graph_def.ParseFromString(input_contents_txt)) {
|
|
||||||
PyErr_SetString(PyExc_ValueError,
|
|
||||||
"Failed to convert GraphDef to Python String.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& dump_options = *GraphVizDumpOptions::singleton();
|
auto& dump_options = *GraphVizDumpOptions::singleton();
|
||||||
if (toco_flags.has_dump_graphviz_dir()) {
|
if (toco_flags.has_dump_graphviz_dir()) {
|
||||||
dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
|
dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
|
||||||
@ -165,13 +159,25 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
|||||||
|
|
||||||
// Convert model.
|
// Convert model.
|
||||||
if (enable_mlir_converter) {
|
if (enable_mlir_converter) {
|
||||||
status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
|
if (!model_flags.saved_model_dir().empty()) {
|
||||||
model_flags, toco_flags, debug_info, graph_def,
|
status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
|
||||||
&output_file_contents_txt);
|
model_flags, toco_flags, &output_file_contents_txt);
|
||||||
if (!toco_flags.conversion_summary_dir().empty()) {
|
} else {
|
||||||
PopulateConversionLogHelper(model_flags, &toco_flags, input_contents_txt,
|
tensorflow::GraphDef graph_def;
|
||||||
output_file_contents_txt,
|
if (!graph_def.ParseFromString(input_contents_txt)) {
|
||||||
status.error_message(), &dump_options);
|
PyErr_SetString(PyExc_ValueError,
|
||||||
|
"Failed to convert GraphDef to Python String.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
|
||||||
|
model_flags, toco_flags, debug_info, graph_def,
|
||||||
|
&output_file_contents_txt);
|
||||||
|
if (!toco_flags.conversion_summary_dir().empty()) {
|
||||||
|
PopulateConversionLogHelper(
|
||||||
|
model_flags, &toco_flags, input_contents_txt,
|
||||||
|
output_file_contents_txt, status.error_message(), &dump_options);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
status = Convert(input_contents_txt, toco_flags, model_flags,
|
status = Convert(input_contents_txt, toco_flags, model_flags,
|
||||||
|
@ -5,7 +5,7 @@ tf_class {
|
|||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "convert"
|
name: "convert"
|
||||||
|
@ -5,7 +5,7 @@ tf_class {
|
|||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'funcs\', \'trackable_obj\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'funcs\', \'trackable_obj\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "convert"
|
name: "convert"
|
||||||
|
Loading…
Reference in New Issue
Block a user