Refactor keras dependency code
PiperOrigin-RevId: 339954904 Change-Id: Id9f6717da5c32bff185a10a37e9682be64cc6501
This commit is contained in:
parent
431b12d9c5
commit
0b752d9223
@ -149,7 +149,6 @@ py_library(
|
|||||||
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
|
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
|
||||||
"//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py",
|
"//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py",
|
||||||
"//tensorflow/lite/experimental/tensorboard:ops_util",
|
"//tensorflow/lite/experimental/tensorboard:ops_util",
|
||||||
"//tensorflow/lite/python/keras/saving:saving_utils",
|
|
||||||
"//tensorflow/lite/python/optimize:calibrator",
|
"//tensorflow/lite/python/optimize:calibrator",
|
||||||
"//tensorflow/python:graph_util",
|
"//tensorflow/python:graph_util",
|
||||||
"//tensorflow/python/keras",
|
"//tensorflow/python/keras",
|
||||||
@ -236,6 +235,7 @@ py_library(
|
|||||||
":op_hint",
|
":op_hint",
|
||||||
":schema_py",
|
":schema_py",
|
||||||
":schema_util",
|
":schema_util",
|
||||||
|
"//tensorflow/lite/python:tflite_keras_util",
|
||||||
"//tensorflow/lite/toco:toco_flags_proto_py",
|
"//tensorflow/lite/toco:toco_flags_proto_py",
|
||||||
"//tensorflow/python:convert_to_constants",
|
"//tensorflow/python:convert_to_constants",
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
@ -278,6 +278,18 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "tflite_keras_util",
|
||||||
|
srcs = [
|
||||||
|
"tflite_keras_util.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
"//tensorflow/python/eager:def_function",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "wrap_toco",
|
name = "wrap_toco",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -1,16 +0,0 @@
|
|||||||
package(
|
|
||||||
default_visibility = ["//visibility:public"],
|
|
||||||
licenses = ["notice"], # Apache 2.0
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "saving_utils",
|
|
||||||
srcs = [
|
|
||||||
"saving_utils.py",
|
|
||||||
],
|
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/python:util",
|
|
||||||
"//tensorflow/python/eager:def_function",
|
|
||||||
],
|
|
||||||
)
|
|
@ -50,7 +50,6 @@ from tensorflow.lite.python.convert import toco_convert_protos # pylint: disabl
|
|||||||
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
|
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
|
||||||
from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
|
from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
|
||||||
from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
|
from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
|
||||||
from tensorflow.lite.python.keras.saving import saving_utils as _keras_saving_utils
|
|
||||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
|
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
|
||||||
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
|
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
|
||||||
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
|
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
|
||||||
@ -63,9 +62,11 @@ from tensorflow.lite.python.util import get_grappler_config as _get_grappler_con
|
|||||||
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
|
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
|
||||||
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
|
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
|
||||||
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
|
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
|
||||||
|
from tensorflow.lite.python.util import model_input_signature as _model_input_signature
|
||||||
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
|
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
|
||||||
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
|
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
|
||||||
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
|
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
|
||||||
|
from tensorflow.lite.python.util import trace_model_call as _trace_model_call
|
||||||
from tensorflow.python.client import session as _session
|
from tensorflow.python.client import session as _session
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function as _def_function
|
from tensorflow.python.eager import def_function as _def_function
|
||||||
@ -839,12 +840,11 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
|
|||||||
# Pass `keep_original_batch_size=True` will ensure that we get an input
|
# Pass `keep_original_batch_size=True` will ensure that we get an input
|
||||||
# signature including the batch dimension specified by the user.
|
# signature including the batch dimension specified by the user.
|
||||||
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
||||||
input_signature = _keras_saving_utils.model_input_signature(
|
input_signature = _model_input_signature(
|
||||||
self._keras_model, keep_original_batch_size=True)
|
self._keras_model, keep_original_batch_size=True)
|
||||||
|
|
||||||
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
||||||
func = _keras_saving_utils.trace_model_call(
|
func = _trace_model_call(self._keras_model, input_signature)
|
||||||
self._keras_model, input_signature)
|
|
||||||
concrete_func = func.get_concrete_function()
|
concrete_func = func.get_concrete_function()
|
||||||
self._funcs = [concrete_func]
|
self._funcs = [concrete_func]
|
||||||
|
|
||||||
@ -1468,7 +1468,7 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
|||||||
|
|
||||||
keras_model = keras_deps.get_load_model_function()(
|
keras_model = keras_deps.get_load_model_function()(
|
||||||
model_file, custom_objects)
|
model_file, custom_objects)
|
||||||
function = _keras_saving_utils.trace_model_call(keras_model)
|
function = _trace_model_call(keras_model)
|
||||||
concrete_func = function.get_concrete_function()
|
concrete_func = function.get_concrete_function()
|
||||||
|
|
||||||
frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
|
frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
|
||||||
|
@ -13,7 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
"""Utility functions for TensorFlow models."""
|
"""Keras functions required by TensorFlow Lite.
|
||||||
|
|
||||||
|
The functions defined in this library have been copied over from Keras in order
|
||||||
|
to remove the dependency from TensorFlow Lite to Keras. The functions which
|
||||||
|
could not be copied over are accessed using the dependecy inversion principle.
|
||||||
|
(for details, refer to tensorflow/python/util/keras_deps.py).
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
@ -33,6 +33,7 @@ from tensorflow.core.protobuf import graph_debug_info_pb2
|
|||||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||||
from tensorflow.lite.python import schema_util
|
from tensorflow.lite.python import schema_util
|
||||||
|
from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
|
||||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
|
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
|
||||||
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
|
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
|
||||||
from tensorflow.lite.toco import types_pb2 as _types_pb2
|
from tensorflow.lite.toco import types_pb2 as _types_pb2
|
||||||
@ -44,6 +45,10 @@ from tensorflow.python.framework import graph_util as tf_graph_util
|
|||||||
from tensorflow.python.grappler import tf_optimizer
|
from tensorflow.python.grappler import tf_optimizer
|
||||||
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
|
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
|
||||||
|
|
||||||
|
# Keras functions used by TFLite
|
||||||
|
model_input_signature = _tflite_keras_util.model_input_signature
|
||||||
|
trace_model_call = _tflite_keras_util.trace_model_call
|
||||||
|
|
||||||
# Map of tf.dtypes to TFLite types_flag_pb2.
|
# Map of tf.dtypes to TFLite types_flag_pb2.
|
||||||
_MAP_TF_TO_TFLITE_TYPES = {
|
_MAP_TF_TO_TFLITE_TYPES = {
|
||||||
dtypes.float32: _types_pb2.FLOAT,
|
dtypes.float32: _types_pb2.FLOAT,
|
||||||
|
@ -417,6 +417,8 @@ def call_context():
|
|||||||
return call_ctx
|
return call_ctx
|
||||||
|
|
||||||
|
|
||||||
|
# Inject the call_context function to keras_deps to remove the dependency
|
||||||
|
# from TFLite to Keras.
|
||||||
keras_deps.register_call_context_function(call_context)
|
keras_deps.register_call_context_function(call_context)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user