Remove Keras dependency from TFLite (lite.py)

PiperOrigin-RevId: 335700329
Change-Id: Ie7436b2cb8746c11e427006573b2835a2ca058cd
This commit is contained in:
Meghna Natraj 2020-10-06 12:50:32 -07:00 committed by TensorFlower Gardener
parent 96fb44b7b4
commit f7dfcc3468
4 changed files with 102 additions and 1 deletions

View File

@ -128,6 +128,7 @@ py_library(
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
"//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py",
"//tensorflow/lite/experimental/tensorboard:ops_util",
"//tensorflow/lite/python/keras/saving:saving_utils",
"//tensorflow/lite/python/optimize:calibrator",
"//tensorflow/python:graph_util",
"//tensorflow/python/keras",

View File

@ -0,0 +1,15 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "saving_utils",
srcs = [
"saving_utils.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
],
)

View File

@ -0,0 +1,83 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for TensorFlow models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
def _enforce_names_consistency(specs):
"""Enforces that either all specs have names or none do."""
def _has_name(spec):
return hasattr(spec, 'name') and spec.name is not None
def _clear_name(spec):
spec = copy.deepcopy(spec)
if hasattr(spec, 'name'):
spec._name = None # pylint:disable=protected-access
return spec
flat_specs = nest.flatten(specs)
name_inconsistency = (
any(_has_name(s) for s in flat_specs) and
not all(_has_name(s) for s in flat_specs))
if name_inconsistency:
specs = nest.map_structure(_clear_name, specs)
return specs
def model_input_signature(model, keep_original_batch_size=False):
"""Inspect model to get its input signature.
The model's input signature is a list with a single (possibly-nested) object.
This is due to the Keras-enforced restriction that tensor inputs must be
passed in as the first argument.
For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]
Args:
model: Keras Model object.
keep_original_batch_size: A boolean indicating whether we want to keep using
the original batch size or set it to None. Default is `False`, which means
that the batch dim of the returned input signature will always be set to
`None`.
Returns:
A list containing either a single TensorSpec or an object with nested
TensorSpecs. This list does not contain the `training` argument.
"""
input_specs = model._get_save_spec(dynamic_batch=not keep_original_batch_size) # pylint: disable=protected-access
if input_specs is None:
return None
input_specs = _enforce_names_consistency(input_specs)
# Return a list with a single element as the model's input signature.
if isinstance(input_specs,
collections_abc.Sequence) and len(input_specs) == 1:
# Note that the isinstance check filters out single-element dictionaries,
# which should also be wrapped as a single-element list.
return input_specs
else:
return [input_specs]

View File

@ -50,6 +50,7 @@ from tensorflow.lite.python.convert import toco_convert_protos # pylint: disabl
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
from tensorflow.lite.python.keras.saving import saving_utils as _keras_saving_utils
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
@ -858,7 +859,8 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
if not isinstance(self._keras_model.call, _def_function.Function):
# Pass `keep_original_batch_size=True` will ensure that we get an input
# signature including the batch dimension specified by the user.
input_signature = _saving_utils.model_input_signature(
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
input_signature = _keras_saving_utils.model_input_signature(
self._keras_model, keep_original_batch_size=True)
func = _saving_utils.trace_model_call(self._keras_model, input_signature)