diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index e4eeaebe900..c48220ef7d8 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -128,6 +128,7 @@ py_library( "//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops", "//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py", "//tensorflow/lite/experimental/tensorboard:ops_util", + "//tensorflow/lite/python/keras/saving:saving_utils", "//tensorflow/lite/python/optimize:calibrator", "//tensorflow/python:graph_util", "//tensorflow/python/keras", diff --git a/tensorflow/lite/python/keras/saving/BUILD b/tensorflow/lite/python/keras/saving/BUILD new file mode 100644 index 00000000000..ff5c679a527 --- /dev/null +++ b/tensorflow/lite/python/keras/saving/BUILD @@ -0,0 +1,15 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +py_library( + name = "saving_utils", + srcs = [ + "saving_utils.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/lite/python/keras/saving/saving_utils.py b/tensorflow/lite/python/keras/saving/saving_utils.py new file mode 100644 index 00000000000..03a442d2ee3 --- /dev/null +++ b/tensorflow/lite/python/keras/saving/saving_utils.py @@ -0,0 +1,83 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utility functions for TensorFlow models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +from tensorflow.python.util import nest +from tensorflow.python.util.compat import collections_abc + + +def _enforce_names_consistency(specs): + """Enforces that either all specs have names or none do.""" + + def _has_name(spec): + return hasattr(spec, 'name') and spec.name is not None + + def _clear_name(spec): + spec = copy.deepcopy(spec) + if hasattr(spec, 'name'): + spec._name = None # pylint:disable=protected-access + return spec + + flat_specs = nest.flatten(specs) + name_inconsistency = ( + any(_has_name(s) for s in flat_specs) and + not all(_has_name(s) for s in flat_specs)) + + if name_inconsistency: + specs = nest.map_structure(_clear_name, specs) + return specs + + +def model_input_signature(model, keep_original_batch_size=False): + """Inspect model to get its input signature. + + The model's input signature is a list with a single (possibly-nested) object. + This is due to the Keras-enforced restriction that tensor inputs must be + passed in as the first argument. + + For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>} + will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}] + + Args: + model: Keras Model object. + keep_original_batch_size: A boolean indicating whether we want to keep using + the original batch size or set it to None. Default is `False`, which means + that the batch dim of the returned input signature will always be set to + `None`. + + Returns: + A list containing either a single TensorSpec or an object with nested + TensorSpecs. This list does not contain the `training` argument. + """ + input_specs = model._get_save_spec(dynamic_batch=not keep_original_batch_size) # pylint: disable=protected-access + if input_specs is None: + return None + input_specs = _enforce_names_consistency(input_specs) + # Return a list with a single element as the model's input signature. + if isinstance(input_specs, + collections_abc.Sequence) and len(input_specs) == 1: + # Note that the isinstance check filters out single-element dictionaries, + # which should also be wrapped as a single-element list. + return input_specs + else: + return [input_specs] + diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 1d30f50c155..d518d8675d8 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -50,6 +50,7 @@ from tensorflow.lite.python.convert import toco_convert_protos # pylint: disabl from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import +from tensorflow.lite.python.keras.saving import saving_utils as _keras_saving_utils from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import @@ -858,7 +859,8 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2): if not isinstance(self._keras_model.call, _def_function.Function): # Pass `keep_original_batch_size=True` will ensure that we get an input # signature including the batch dimension specified by the user. - input_signature = _saving_utils.model_input_signature( + # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF + input_signature = _keras_saving_utils.model_input_signature( self._keras_model, keep_original_batch_size=True) func = _saving_utils.trace_model_call(self._keras_model, input_signature)