Remove Keras dependency from TFLite (lite.py)
PiperOrigin-RevId: 335700329 Change-Id: Ie7436b2cb8746c11e427006573b2835a2ca058cd
This commit is contained in:
parent
96fb44b7b4
commit
f7dfcc3468
@ -128,6 +128,7 @@ py_library(
|
||||
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
|
||||
"//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py",
|
||||
"//tensorflow/lite/experimental/tensorboard:ops_util",
|
||||
"//tensorflow/lite/python/keras/saving:saving_utils",
|
||||
"//tensorflow/lite/python/optimize:calibrator",
|
||||
"//tensorflow/python:graph_util",
|
||||
"//tensorflow/python/keras",
|
||||
|
15
tensorflow/lite/python/keras/saving/BUILD
Normal file
15
tensorflow/lite/python/keras/saving/BUILD
Normal file
@ -0,0 +1,15 @@
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "saving_utils",
|
||||
srcs = [
|
||||
"saving_utils.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
83
tensorflow/lite/python/keras/saving/saving_utils.py
Normal file
83
tensorflow/lite/python/keras/saving/saving_utils.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Utility functions for TensorFlow models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
|
||||
|
||||
def _enforce_names_consistency(specs):
|
||||
"""Enforces that either all specs have names or none do."""
|
||||
|
||||
def _has_name(spec):
|
||||
return hasattr(spec, 'name') and spec.name is not None
|
||||
|
||||
def _clear_name(spec):
|
||||
spec = copy.deepcopy(spec)
|
||||
if hasattr(spec, 'name'):
|
||||
spec._name = None # pylint:disable=protected-access
|
||||
return spec
|
||||
|
||||
flat_specs = nest.flatten(specs)
|
||||
name_inconsistency = (
|
||||
any(_has_name(s) for s in flat_specs) and
|
||||
not all(_has_name(s) for s in flat_specs))
|
||||
|
||||
if name_inconsistency:
|
||||
specs = nest.map_structure(_clear_name, specs)
|
||||
return specs
|
||||
|
||||
|
||||
def model_input_signature(model, keep_original_batch_size=False):
|
||||
"""Inspect model to get its input signature.
|
||||
|
||||
The model's input signature is a list with a single (possibly-nested) object.
|
||||
This is due to the Keras-enforced restriction that tensor inputs must be
|
||||
passed in as the first argument.
|
||||
|
||||
For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
|
||||
will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]
|
||||
|
||||
Args:
|
||||
model: Keras Model object.
|
||||
keep_original_batch_size: A boolean indicating whether we want to keep using
|
||||
the original batch size or set it to None. Default is `False`, which means
|
||||
that the batch dim of the returned input signature will always be set to
|
||||
`None`.
|
||||
|
||||
Returns:
|
||||
A list containing either a single TensorSpec or an object with nested
|
||||
TensorSpecs. This list does not contain the `training` argument.
|
||||
"""
|
||||
input_specs = model._get_save_spec(dynamic_batch=not keep_original_batch_size) # pylint: disable=protected-access
|
||||
if input_specs is None:
|
||||
return None
|
||||
input_specs = _enforce_names_consistency(input_specs)
|
||||
# Return a list with a single element as the model's input signature.
|
||||
if isinstance(input_specs,
|
||||
collections_abc.Sequence) and len(input_specs) == 1:
|
||||
# Note that the isinstance check filters out single-element dictionaries,
|
||||
# which should also be wrapped as a single-element list.
|
||||
return input_specs
|
||||
else:
|
||||
return [input_specs]
|
||||
|
@ -50,6 +50,7 @@ from tensorflow.lite.python.convert import toco_convert_protos # pylint: disabl
|
||||
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
|
||||
from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.keras.saving import saving_utils as _keras_saving_utils
|
||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
|
||||
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
|
||||
@ -858,7 +859,8 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
|
||||
if not isinstance(self._keras_model.call, _def_function.Function):
|
||||
# Pass `keep_original_batch_size=True` will ensure that we get an input
|
||||
# signature including the batch dimension specified by the user.
|
||||
input_signature = _saving_utils.model_input_signature(
|
||||
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
||||
input_signature = _keras_saving_utils.model_input_signature(
|
||||
self._keras_model, keep_original_batch_size=True)
|
||||
|
||||
func = _saving_utils.trace_model_call(self._keras_model, input_signature)
|
||||
|
Loading…
Reference in New Issue
Block a user