Enable Keras/RNN case via MLIR SavedModel import in TFLiteConverterV2

PiperOrigin-RevId: 304694033
Change-Id: I3c2586b92e1b4a810036ed390cb5b4d83352bef8
Jaesung Chung 2020-04-03 14:29:59 -07:00 committed by TensorFlower Gardener
parent 9a85d6fe42
commit f21e640f0e
16 changed files with 213 additions and 62 deletions

View File

@@ -771,6 +771,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",

View File

@@ -84,8 +84,14 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
TF_ASSIGN_OR_RETURN(
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
mlir::TFL::PassConfig pass_config(quant_specs);
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
pass_config.shape_inference = false;
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
quant_specs, result);
pass_config, result);
}
} // namespace tensorflow

View File

@@ -43,8 +43,6 @@ namespace tensorflow {
Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
const string& saved_model_dir, bool saved_model_v1,
const string& saved_model_tags, const string& saved_model_exported_names,
string* result) {
mlir::MLIRContext context;
mlir::TFL::QuantizationSpecs quant_specs;
@@ -66,13 +64,28 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
// Register all custom ops, including user-specified custom ops.
TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));
const bool import_saved_model = !saved_model_v1;
TF_ASSIGN_OR_RETURN(
auto module,
ImportSavedModel(import_saved_model, saved_model_v1, saved_model_dir,
saved_model_tags, saved_model_exported_names, &context));
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
quant_specs, result);
auto& saved_model_tags = model_flags.saved_model_tags();
auto& saved_model_exported_names = model_flags.saved_model_exported_names();
std::unordered_set<std::string> tags(saved_model_tags.begin(),
saved_model_tags.end());
auto exported_names_in_vector = std::vector<std::string>(
saved_model_exported_names.begin(), saved_model_exported_names.end());
absl::Span<std::string> exported_names(exported_names_in_vector);
TF_ASSIGN_OR_RETURN(auto module,
ImportSavedModel(model_flags.saved_model_dir(),
model_flags.saved_model_version(), tags,
exported_names, &context));
mlir::TFL::PassConfig pass_config(quant_specs);
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
pass_config.shape_inference = true;
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, result);
return status;
}
} // namespace tensorflow

View File

@@ -28,8 +28,6 @@ namespace tensorflow {
// status if it fails to convert the input.
Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
const string& saved_model_dir, bool saved_model_v1,
const string& saved_model_tags, const string& saved_model_exported_names,
string* result);
} // namespace tensorflow

View File

@@ -261,7 +261,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module,
mlir::TFL::QuantizationSpecs quant_specs,
const mlir::TFL::PassConfig& pass_config,
string* result) {
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
@@ -275,9 +275,6 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
}
mlir::PassManager pm(module->getContext());
mlir::TFL::PassConfig pass_config(quant_specs);
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
// Convert back to outlined while format for export back to flatbuffer.
@@ -288,7 +285,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result,
&pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.

View File

@@ -47,7 +47,7 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
// This will also run the relevant passes.
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module,
mlir::TFL::QuantizationSpecs quant_specs,
const mlir::TFL::PassConfig& pass_config,
string* result);
// Give a warning for any unused flags that have been specified.

View File

@@ -138,13 +138,24 @@ int main(int argc, char **argv) {
// TODO(b/147435528): We need to test the e2e behavior once the graph freezing
// inside mlir is done.
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
int saved_model_version;
if (import_saved_model_object_graph) {
saved_model_version = 2;
} else {
saved_model_version = 1;
}
if (input_mlir)
module = tensorflow::errors::InvalidArgument(
"Importing saved model should not have input_mlir set");
module = tensorflow::ImportSavedModel(import_saved_model_object_graph,
import_saved_model_signature_defs,
input_file_name, saved_model_tags,
saved_model_exported_names, &context);
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names_vector =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_vector);
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
tags, exported_names, &context);
} else {
module = tensorflow::LoadFromGraphdefOrMlirSource(
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,

View File

@@ -160,25 +160,17 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
}
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
bool import_saved_model, bool import_saved_model_v1,
const std::string& input_filename, const std::string& saved_model_tags,
const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
if (import_saved_model) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
const std::string& input_filename, const int saved_model_version,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
if (saved_model_version == 2) {
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, absl::Span<std::string>(exported_names), context);
input_filename, tags, exported_names, context);
if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file");
return module;
} else if (import_saved_model_v1) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
} else if (saved_model_version == 1) {
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, context);

View File

@@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
#include <unordered_set>
#include "absl/types/span.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
@@ -42,9 +45,9 @@ LoadFromGraphdefOrMlirSource(
// Load a SavedModel (either v1 or v2) into MLIR.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
bool import_saved_model, bool import_saved_model_v1,
const std::string& input_filename, const std::string& saved_model_tags,
const std::string& saved_model_exported_names, mlir::MLIRContext* context);
const std::string& input_filename, const int saved_model_version,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
// Taking an MLIR module in TF executor dialect and a set of parameters,
// applies a set of passes to convert the module to TF Lite dialect and

View File

@@ -257,7 +257,11 @@ def build_toco_convert_protos(input_tensors,
target_ops=None,
allow_nonexistent_arrays=False,
debug_info=None,
conversion_summary_dir=None):
conversion_summary_dir=None,
saved_model_dir=None,
saved_model_version=0,
saved_model_tags=None,
saved_model_exported_names=None):
"""Builds protocol buffers describing a conversion of a model using TOCO.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -323,6 +327,18 @@
debug_info: `GraphDebugInfo` proto containing the stack traces for the
original nodes referred by the converted graph.
conversion_summary_dir: A string, the path to the generated conversion logs.
saved_model_dir: Directory path of the SavedModel to be converted. This
value is non-empty only when the SavedModel import path is used;
otherwise, the GraphDef-based conversion path is taken.
saved_model_version: SavedModel file format version of the model to be
converted. This value is set only when the SavedModel import path is
used.
saved_model_tags: Set of SavedModel tags, formatted as comma-separated
values. This value is set only when the SavedModel import path is used.
saved_model_exported_names: Names to be exported (default: export all).
This value is set only when the SavedModel import path is used.
Returns:
model_flags, toco_flags, debug_info: three protocol buffers describing the
@@ -397,6 +413,14 @@
model.allow_nonexistent_arrays = allow_nonexistent_arrays
if saved_model_dir:
model.saved_model_dir = saved_model_dir
model.saved_model_version = saved_model_version
if saved_model_tags:
model.saved_model_tags.extend(saved_model_tags)
if saved_model_exported_names:
model.saved_model_exported_names.extend(saved_model_exported_names)
return model, toco, debug_info
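For reference, a hedged sketch of how a caller might exercise the new keyword arguments in this hunk. The placeholder graph and the SavedModel directory are assumptions for illustration, not part of this change:

import tensorflow.compat.v1 as tf
from tensorflow.lite.python.convert import build_toco_convert_protos

with tf.Graph().as_default():
  inp = tf.placeholder(tf.float32, [1, 16], name="input")
  out = tf.identity(inp, name="output")
  model_flags, toco_flags, debug_info = build_toco_convert_protos(
      [inp], [out],
      saved_model_dir="/tmp/keras_lstm",   # hypothetical SavedModel directory
      saved_model_version=2,               # TF2 object-graph SavedModel
      saved_model_tags=set(["serve"]),
      saved_model_exported_names=[])       # empty => export all signatures

When saved_model_dir is non-empty, the resulting ModelFlags carries the fields that route conversion through the SavedModel import path instead of the GraphDef path.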

View File

@@ -74,6 +74,7 @@ from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
from tensorflow.python.saved_model.load import load as _load
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util.tf_export import tf_export as _tf_export
@@ -285,6 +286,10 @@ class TFLiteConverterBase(object):
# The 'GraphDebugInfo' contains the stack traces of all the original nodes
# in the `GraphDef` to the converter.
self._debug_info = None
self._saved_model_dir = None
self._saved_model_tags = None
self._saved_model_version = None
self._saved_model_exported_names = []
def _grappler_config(self, optimizers=None):
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler.
@@ -346,8 +351,46 @@
"target_ops": self.target_spec.supported_ops,
"enable_mlir_converter": self.experimental_new_converter,
}
if self._saved_model_dir:
args.update({
"saved_model_dir": self._saved_model_dir,
"saved_model_version": self._saved_model_version,
"saved_model_tags": self._saved_model_tags,
"saved_model_exported_names": self._saved_model_exported_names,
})
return args
def _contains_function_with_implements_attr(self, saved_model_proto):
meta_graph = saved_model_proto.meta_graphs[0]
for function in meta_graph.graph_def.library.function:
if function.attr.get("_implements", None) or function.attr.get(
"api_implements", None):
return True
return False
def _parse_saved_model_args(self):
"""Parses SavedModel arguments from the given Keras/RNN SavedModel."""
if self._saved_model_dir:
try:
saved_model_proto, _ = (
_parse_saved_model_with_debug_info(self._saved_model_dir))
except OSError:
# If the given SavedModel cannot be read, fall back to the frozen
# GraphDef path.
self._saved_model_dir = None
return
if not self._contains_function_with_implements_attr(saved_model_proto):
self._saved_model_dir = None
else:
self._saved_model_exported_names = []
self._saved_model_version = saved_model_proto.saved_model_schema_version
if self._saved_model_version not in [1, 2]:
raise ValueError(
"SavedModel file format({0}) is not supported".format(
self._saved_model_version))
@_tf_export("lite.TFLiteConverter", v1=[])
class TFLiteConverterV2(TFLiteConverterBase):
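The _implements/api_implements check above is what gates the new import path: Keras emits fused RNN functions (e.g. for LSTM/GRU layers) carrying an _implements attribute, and only models containing such functions keep _saved_model_dir set; all others fall back to the frozen-GraphDef path. A standalone sketch of the same check, assuming a SavedModel directory on disk:

from tensorflow.python.saved_model.loader_impl import parse_saved_model

def uses_implements_attr(saved_model_dir):
  # Mirrors _contains_function_with_implements_attr above; the directory
  # path passed in is an assumption for illustration.
  proto = parse_saved_model(saved_model_dir)
  meta_graph = proto.meta_graphs[0]
  for fn in meta_graph.graph_def.library.function:
    if fn.attr.get("_implements") or fn.attr.get("api_implements"):
      return True
  return False

print(uses_implements_attr("/tmp/keras_lstm"))  # hypothetical path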
@@ -387,7 +430,11 @@ class TFLiteConverterV2(TFLiteConverterBase):
```
"""
def __init__(self, funcs, trackable_obj=None):
def __init__(self,
funcs,
trackable_obj=None,
saved_model_dir=None,
saved_model_tags=None):
"""Constructor for TFLiteConverter.
Args:
@@ -398,10 +445,19 @@ class TFLiteConverterV2(TFLiteConverterBase):
get garbage collected since functions have a weak reference to
Variables. This is only required when the tf.AutoTrackable object is not
maintained by the user (e.g. `from_saved_model`).
saved_model_dir: Directory of the SavedModel. This argument can be None
when the converter is created via the from_keras_model or
from_concrete_functions methods.
saved_model_tags: Set of tags identifying the MetaGraphDef within the
SavedModel to analyze. All tags in the tag set must be present. (default
set(SERVING)). This argument is only used when saved_model_dir is set.
"""
super(TFLiteConverterV2, self).__init__()
self._funcs = funcs
self._trackable_obj = trackable_obj
self._saved_model_dir = saved_model_dir
self._saved_model_tags = saved_model_tags
@classmethod
def from_concrete_functions(cls, funcs):
@@ -463,6 +519,9 @@ class TFLiteConverterV2(TFLiteConverterBase):
# Ensures any graphs created in Eager mode are able to run. This is required
# in order to create a tf.estimator.Exporter that exports a TFLite model.
if tags is None:
tags = set([_tag_constants.SERVING])
with context.eager_mode():
saved_model = _load(saved_model_dir, tags)
if not signature_keys:
@@ -475,7 +534,7 @@
"'{}'.".format(key, ",".join(saved_model.signatures)))
funcs.append(saved_model.signatures[key])
return cls(funcs, saved_model)
return cls(funcs, saved_model, saved_model_dir, tags)
@classmethod
def from_keras_model(cls, model):
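With from_saved_model now threading the directory and tags into the constructor, this is the end-to-end scenario the commit title describes. A sketch; the model architecture and the save path are assumptions:

import tensorflow as tf

# A Keras RNN model: its LSTM layer is exported into the SavedModel as a
# function carrying an _implements attribute, which triggers the new
# MLIR SavedModel import path.
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(4, input_shape=(28, 28)),
    tf.keras.layers.Dense(10),
])
model.save("/tmp/keras_lstm", save_format="tf")  # hypothetical directory

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/keras_lstm")
converter.experimental_new_converter = True  # route through MLIR
tflite_model = converter.convert()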
@@ -521,6 +580,9 @@ class TFLiteConverterV2(TFLiteConverterBase):
"ConcreteFunction. Converting multiple functions is "
"under development.")
# Parses SavedModel argument.
self._parse_saved_model_args()
# graph_def is used here to preserve the node debug information
frozen_func, graph_def = (
_convert_to_constants.convert_variables_to_constants_v2_as_graph(
@@ -693,6 +755,7 @@ class TFLiteConverter(TFLiteConverterBase):
the dataset to evaluate different optimizations.
experimental_new_converter: Experimental flag, subject to change.
Enables MLIR-based conversion instead of TOCO conversion.
Example usage:
```python
@@ -725,7 +788,9 @@
output_tensors,
input_arrays_with_shape=None,
output_arrays=None,
experimental_debug_info_func=None):
experimental_debug_info_func=None,
saved_model_dir=None,
saved_model_tags=None):
"""Constructor for TFLiteConverter.
Args:
@@ -743,6 +808,13 @@
`output_tensors` are None. (default None)
experimental_debug_info_func: An experimental function to retrieve the
graph debug info for a set of nodes from the `graph_def`.
saved_model_dir: Directory of the SavedModel. This argument can be None
when the converter is created via the from_keras_model or
from_concrete_functions methods.
saved_model_tags: Set of tags identifying the MetaGraphDef within the
SavedModel to analyze. All tags in the tag set must be present. (default
set(SERVING)). This argument is only used when saved_model_dir is set.
Raises:
ValueError: Invalid arguments.
@@ -766,6 +838,8 @@
self.conversion_summary_dir = None
self._debug_info_func = experimental_debug_info_func
self._custom_opdefs = None
self._saved_model_dir = saved_model_dir
self._saved_model_tags = saved_model_tags
# Attributes are used by models that cannot be loaded into TensorFlow.
if not self._has_valid_tensors():
@@ -928,7 +1002,9 @@
graph_def=result[0],
input_tensors=result[1],
output_tensors=result[2],
experimental_debug_info_func=_build_debug_info_func(result[3]))
experimental_debug_info_func=_build_debug_info_func(result[3]),
saved_model_dir=saved_model_dir,
saved_model_tags=tag_set)
@classmethod
def from_keras_model_file(cls,
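The V1 converter gets the same threading, as the hunk above shows: from_saved_model now forwards the directory and tag set to the constructor so _parse_saved_model_args can inspect the model later. A hedged sketch of the compat path; the directory is assumed:

import tensorflow.compat.v1 as tf

converter = tf.lite.TFLiteConverter.from_saved_model(
    "/tmp/keras_lstm",              # hypothetical SavedModel directory
    tag_set=set(["serve"]))
converter.experimental_new_converter = True
tflite_model = converter.convert()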
@@ -1059,6 +1135,9 @@
Input shape is not specified.
None value for dimension in input_tensor.
"""
# Parses SavedModel argument.
self._parse_saved_model_args()
quant_mode = QuantizationMode(self.optimizations, self.target_spec,
self.representative_dataset, self._graph_def)

View File

@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
import "tensorflow/lite/toco/types.proto";
package toco;
import "tensorflow/lite/toco/types.proto";
message InputArrayShape {
repeated int32 dims = 2;
}
@@ -130,7 +131,7 @@ message ArraysExtraInfo {
// optional int32 input_dims = 11 [ default = 4];
// repeated int32 input_shape = 13;
//
// Next ID to USE: 20.
// Next ID to USE: 24.
message ModelFlags {
// Information about the input arrays, i.e. the arrays from which input
// activations will be read.
@@ -181,4 +182,22 @@ message ModelFlags {
// When set to false, toco will not change the input ranges and the output
// ranges of the concat operator to the overlap of all input ranges.
optional bool change_concat_input_ranges = 19 [default = true];
// Directory path of the SavedModel to be converted. This value is non-empty
// only when the SavedModel import path is used. Otherwise, the GraphDef-based
// conversion path is taken.
optional string saved_model_dir = 20;
// SavedModel file format version of the model to be converted. This value is
// set only when the SavedModel import path is used.
optional int32 saved_model_version = 21;
// Set of SavedModel tags, formatted as comma-separated values. This value is
// set only when the SavedModel import path is used.
repeated string saved_model_tags = 22;
// Names to be exported (default: export all). This value is set only when the
// SavedModel import path is used.
repeated string saved_model_exported_names = 23;
}
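A hedged sketch of the new fields through the generated Python bindings; the values are assumptions for illustration:

from tensorflow.lite.toco import model_flags_pb2

flags = model_flags_pb2.ModelFlags()
flags.saved_model_dir = "/tmp/keras_lstm"  # hypothetical directory
flags.saved_model_version = 2
flags.saved_model_tags.extend(["serve"])
# Leaving saved_model_exported_names empty means "export all".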

View File

@@ -49,6 +49,7 @@ cc_library(
"//tensorflow/lite/toco:tooling_util",
"//tensorflow/core:protos_all_cc",
"//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer",
] + select({
# This is required when running `tflite_convert` from `bazel`.
# It requires to link with TensorFlow Ops to get the op definitions.

View File

@@ -21,6 +21,7 @@ limitations under the License.
#include "google/protobuf/text_format.h"
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
@@ -144,13 +145,6 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
}
}
tensorflow::GraphDef graph_def;
if (!graph_def.ParseFromString(input_contents_txt)) {
PyErr_SetString(PyExc_ValueError,
"Failed to convert GraphDef to Python String.");
return nullptr;
}
auto& dump_options = *GraphVizDumpOptions::singleton();
if (toco_flags.has_dump_graphviz_dir()) {
dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
@@ -165,13 +159,25 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
// Convert model.
if (enable_mlir_converter) {
status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
model_flags, toco_flags, debug_info, graph_def,
&output_file_contents_txt);
if (!toco_flags.conversion_summary_dir().empty()) {
PopulateConversionLogHelper(model_flags, &toco_flags, input_contents_txt,
output_file_contents_txt,
status.error_message(), &dump_options);
if (!model_flags.saved_model_dir().empty()) {
status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
model_flags, toco_flags, &output_file_contents_txt);
} else {
tensorflow::GraphDef graph_def;
if (!graph_def.ParseFromString(input_contents_txt)) {
PyErr_SetString(PyExc_ValueError,
"Failed to convert GraphDef to Python String.");
return nullptr;
}
status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
model_flags, toco_flags, debug_info, graph_def,
&output_file_contents_txt);
if (!toco_flags.conversion_summary_dir().empty()) {
PopulateConversionLogHelper(
model_flags, &toco_flags, input_contents_txt,
output_file_contents_txt, status.error_message(), &dump_options);
}
}
} else {
status = Convert(input_contents_txt, toco_flags, model_flags,

View File

@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "convert"

View File

@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'funcs\', \'trackable_obj\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'funcs\', \'trackable_obj\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "convert"