Remove keras completely from lite.py
PiperOrigin-RevId: 339385259 Change-Id: I8ffab8c686870a42c60565d863506a398c1e1238
This commit is contained in:
parent
091f679cdf
commit
b368310cbf
tensorflow
lite/python
python
@ -86,6 +86,7 @@ py_library(
|
||||
":lite",
|
||||
"//tensorflow/lite/toco/logging:gen_html",
|
||||
"//tensorflow/lite/toco/logging:toco_conversion_log_proto_py",
|
||||
"//tensorflow/python:util",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
@ -111,6 +112,7 @@ py_test(
|
||||
deps = [
|
||||
":convert",
|
||||
":tflite_convert",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
|
@ -66,7 +66,6 @@ from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
|
||||
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
|
||||
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
|
||||
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
|
||||
from tensorflow.python import keras as _keras
|
||||
from tensorflow.python.client import session as _session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function as _def_function
|
||||
@ -83,6 +82,7 @@ from tensorflow.python.saved_model import tag_constants as _tag_constants
|
||||
from tensorflow.python.saved_model.load import load as _load
|
||||
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
|
||||
from tensorflow.python.util import deprecation as _deprecation
|
||||
from tensorflow.python.util import keras_deps
|
||||
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
||||
|
||||
|
||||
@ -1466,10 +1466,8 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
||||
"with Eager mode. If your model requires any of these "
|
||||
"parameters, please use disable_eager_execution().")
|
||||
|
||||
_keras.backend.set_learning_phase(False)
|
||||
keras_model = _keras.models.load_model(model_file, custom_objects)
|
||||
|
||||
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
||||
keras_model = keras_deps.get_load_model_function()(
|
||||
model_file, custom_objects)
|
||||
function = _keras_saving_utils.trace_model_call(keras_model)
|
||||
concrete_func = function.get_concrete_function()
|
||||
|
||||
@ -1484,10 +1482,10 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
||||
return
|
||||
|
||||
# Handles Keras when Eager mode is disabled.
|
||||
_keras.backend.clear_session()
|
||||
_keras.backend.set_learning_phase(False)
|
||||
keras_model = _keras.models.load_model(model_file, custom_objects)
|
||||
sess = _keras.backend.get_session()
|
||||
keras_deps.get_clear_session_function()()
|
||||
keras_model = keras_deps.get_load_model_function()(
|
||||
model_file, custom_objects)
|
||||
sess = keras_deps.get_get_session_function()()
|
||||
|
||||
# Get input and output tensors.
|
||||
if input_arrays:
|
||||
|
@ -31,10 +31,10 @@ from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python.convert import register_custom_opdefs
|
||||
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
|
||||
from tensorflow.lite.toco.logging import gen_html
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.util import keras_deps
|
||||
|
||||
|
||||
def _parse_array(values, type_fn=str):
|
||||
@ -234,7 +234,7 @@ def _convert_tf2_model(flags):
|
||||
if flags.saved_model_dir:
|
||||
converter = lite.TFLiteConverterV2.from_saved_model(flags.saved_model_dir)
|
||||
elif flags.keras_model_file:
|
||||
model = keras.models.load_model(flags.keras_model_file)
|
||||
model = keras_deps.get_load_model_function()(flags.keras_model_file)
|
||||
converter = lite.TFLiteConverterV2.from_keras_model(model)
|
||||
|
||||
if flags.experimental_new_converter is not None:
|
||||
|
@ -21,11 +21,11 @@ from __future__ import print_function
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from tensorflow import keras
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.lite.python import tflite_convert
|
||||
from tensorflow.lite.python.convert import register_custom_opdefs
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import def_function
|
||||
|
@ -82,6 +82,7 @@ from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import moving_averages
|
||||
from tensorflow.python.training.tracking import util as tracking_util
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util import keras_deps
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -318,6 +319,10 @@ def clear_session():
|
||||
_GRAPH_VARIABLES.pop(graph, None)
|
||||
_GRAPH_TF_OPTIMIZERS.pop(graph, None)
|
||||
|
||||
# Inject the clear_session function to keras_deps to remove the dependency
|
||||
# from TFLite to Keras.
|
||||
keras_deps.register_clear_session_function(clear_session)
|
||||
|
||||
|
||||
@keras_export('keras.backend.manual_variable_initialization')
|
||||
@doc_controls.do_not_generate_docs
|
||||
@ -745,6 +750,9 @@ def get_session(op_input_list=()):
|
||||
_initialize_variables(session)
|
||||
return session
|
||||
|
||||
# Inject the get_session function to keras_deps to remove the dependency
|
||||
# from TFLite to Keras.
|
||||
keras_deps.register_get_session_function(get_session)
|
||||
|
||||
# Inject the get_session function to tracking_util to avoid the backward
|
||||
# dependency from TF to Keras.
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||
from tensorflow.python.saved_model import load_context
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.util import keras_deps
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
@ -214,3 +215,7 @@ def load_model(filepath, custom_objects=None, compile=True, options=None): # py
|
||||
raise IOError(
|
||||
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
|
||||
'available) or SavedModel.')
|
||||
|
||||
# Inject the load_model function to keras_deps to remove the dependency
|
||||
# from TFLite to Keras.
|
||||
keras_deps.register_load_model_function(load_model)
|
||||
|
@ -30,15 +30,52 @@ from __future__ import print_function
|
||||
|
||||
|
||||
_KERAS_CALL_CONTEXT_FUNCTION = None
|
||||
_KERAS_CLEAR_SESSION_FUNCTION = None
|
||||
_KERAS_GET_SESSION_FUNCTION = None
|
||||
_KERAS_LOAD_MODEL_FUNCTION = None
|
||||
|
||||
# TODO(scottzhu): Disable duplicated inject once keras is moved to
|
||||
# third_party/py/keras.
|
||||
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
||||
|
||||
|
||||
# Register functions
|
||||
def register_call_context_function(func):
|
||||
global _KERAS_CALL_CONTEXT_FUNCTION
|
||||
# TODO(scottzhu): Disable duplicated inject once keras is moved to
|
||||
# third_party/py/keras.
|
||||
_KERAS_CALL_CONTEXT_FUNCTION = func
|
||||
|
||||
|
||||
def register_clear_session_function(func):
|
||||
global _KERAS_CLEAR_SESSION_FUNCTION
|
||||
_KERAS_CLEAR_SESSION_FUNCTION = func
|
||||
|
||||
|
||||
def register_get_session_function(func):
|
||||
global _KERAS_GET_SESSION_FUNCTION
|
||||
_KERAS_GET_SESSION_FUNCTION = func
|
||||
|
||||
|
||||
def register_load_model_function(func):
|
||||
global _KERAS_LOAD_MODEL_FUNCTION
|
||||
_KERAS_LOAD_MODEL_FUNCTION = func
|
||||
|
||||
|
||||
# Get functions
|
||||
def get_call_context_function():
|
||||
global _KERAS_CALL_CONTEXT_FUNCTION
|
||||
return _KERAS_CALL_CONTEXT_FUNCTION
|
||||
|
||||
|
||||
def get_clear_session_function():
|
||||
global _KERAS_CLEAR_SESSION_FUNCTION
|
||||
return _KERAS_CLEAR_SESSION_FUNCTION
|
||||
|
||||
|
||||
def get_get_session_function():
|
||||
global _KERAS_GET_SESSION_FUNCTION
|
||||
return _KERAS_GET_SESSION_FUNCTION
|
||||
|
||||
|
||||
def get_load_model_function():
|
||||
global _KERAS_LOAD_MODEL_FUNCTION
|
||||
return _KERAS_LOAD_MODEL_FUNCTION
|
||||
|
Loading…
Reference in New Issue
Block a user