Refactor keras dependency code
PiperOrigin-RevId: 339954904 Change-Id: Id9f6717da5c32bff185a10a37e9682be64cc6501
This commit is contained in:
parent
431b12d9c5
commit
0b752d9223
@ -149,7 +149,6 @@ py_library(
|
||||
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
|
||||
"//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py",
|
||||
"//tensorflow/lite/experimental/tensorboard:ops_util",
|
||||
"//tensorflow/lite/python/keras/saving:saving_utils",
|
||||
"//tensorflow/lite/python/optimize:calibrator",
|
||||
"//tensorflow/python:graph_util",
|
||||
"//tensorflow/python/keras",
|
||||
@ -236,6 +235,7 @@ py_library(
|
||||
":op_hint",
|
||||
":schema_py",
|
||||
":schema_util",
|
||||
"//tensorflow/lite/python:tflite_keras_util",
|
||||
"//tensorflow/lite/toco:toco_flags_proto_py",
|
||||
"//tensorflow/python:convert_to_constants",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -278,6 +278,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tflite_keras_util",
|
||||
srcs = [
|
||||
"tflite_keras_util.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "wrap_toco",
|
||||
srcs = [
|
||||
|
@ -1,16 +0,0 @@
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "saving_utils",
|
||||
srcs = [
|
||||
"saving_utils.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
],
|
||||
)
|
@ -50,7 +50,6 @@ from tensorflow.lite.python.convert import toco_convert_protos # pylint: disabl
|
||||
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
|
||||
from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.keras.saving import saving_utils as _keras_saving_utils
|
||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
|
||||
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
|
||||
@ -63,9 +62,11 @@ from tensorflow.lite.python.util import get_grappler_config as _get_grappler_con
|
||||
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
|
||||
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
|
||||
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
|
||||
from tensorflow.lite.python.util import model_input_signature as _model_input_signature
|
||||
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
|
||||
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
|
||||
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
|
||||
from tensorflow.lite.python.util import trace_model_call as _trace_model_call
|
||||
from tensorflow.python.client import session as _session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function as _def_function
|
||||
@ -839,12 +840,11 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
|
||||
# Pass `keep_original_batch_size=True` will ensure that we get an input
|
||||
# signature including the batch dimension specified by the user.
|
||||
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
||||
input_signature = _keras_saving_utils.model_input_signature(
|
||||
input_signature = _model_input_signature(
|
||||
self._keras_model, keep_original_batch_size=True)
|
||||
|
||||
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
||||
func = _keras_saving_utils.trace_model_call(
|
||||
self._keras_model, input_signature)
|
||||
func = _trace_model_call(self._keras_model, input_signature)
|
||||
concrete_func = func.get_concrete_function()
|
||||
self._funcs = [concrete_func]
|
||||
|
||||
@ -1468,7 +1468,7 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
||||
|
||||
keras_model = keras_deps.get_load_model_function()(
|
||||
model_file, custom_objects)
|
||||
function = _keras_saving_utils.trace_model_call(keras_model)
|
||||
function = _trace_model_call(keras_model)
|
||||
concrete_func = function.get_concrete_function()
|
||||
|
||||
frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
|
||||
|
@ -13,7 +13,13 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Utility functions for TensorFlow models."""
|
||||
"""Keras functions required by TensorFlow Lite.
|
||||
|
||||
The functions defined in this library have been copied over from Keras in order
|
||||
to remove the dependency from TensorFlow Lite to Keras. The functions which
|
||||
could not be copied over are accessed using the dependecy inversion principle.
|
||||
(for details, refer to tensorflow/python/util/keras_deps.py).
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
@ -33,6 +33,7 @@ from tensorflow.core.protobuf import graph_debug_info_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||
from tensorflow.lite.python import schema_util
|
||||
from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
|
||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
|
||||
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
|
||||
from tensorflow.lite.toco import types_pb2 as _types_pb2
|
||||
@ -44,6 +45,10 @@ from tensorflow.python.framework import graph_util as tf_graph_util
|
||||
from tensorflow.python.grappler import tf_optimizer
|
||||
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
|
||||
|
||||
# Keras functions used by TFLite
|
||||
model_input_signature = _tflite_keras_util.model_input_signature
|
||||
trace_model_call = _tflite_keras_util.trace_model_call
|
||||
|
||||
# Map of tf.dtypes to TFLite types_flag_pb2.
|
||||
_MAP_TF_TO_TFLITE_TYPES = {
|
||||
dtypes.float32: _types_pb2.FLOAT,
|
||||
|
@ -417,6 +417,8 @@ def call_context():
|
||||
return call_ctx
|
||||
|
||||
|
||||
# Inject the call_context function to keras_deps to remove the dependency
|
||||
# from TFLite to Keras.
|
||||
keras_deps.register_call_context_function(call_context)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user