Remove keras completely from lite.py
PiperOrigin-RevId: 339385259 Change-Id: I8ffab8c686870a42c60565d863506a398c1e1238
This commit is contained in:
parent
091f679cdf
commit
b368310cbf
@ -86,6 +86,7 @@ py_library(
|
|||||||
":lite",
|
":lite",
|
||||||
"//tensorflow/lite/toco/logging:gen_html",
|
"//tensorflow/lite/toco/logging:gen_html",
|
||||||
"//tensorflow/lite/toco/logging:toco_conversion_log_proto_py",
|
"//tensorflow/lite/toco/logging:toco_conversion_log_proto_py",
|
||||||
|
"//tensorflow/python:util",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -111,6 +112,7 @@ py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":convert",
|
":convert",
|
||||||
":tflite_convert",
|
":tflite_convert",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:constant_op",
|
"//tensorflow/python:constant_op",
|
||||||
|
@ -66,7 +66,6 @@ from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
|
|||||||
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
|
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
|
||||||
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
|
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
|
||||||
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
|
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
|
||||||
from tensorflow.python import keras as _keras
|
|
||||||
from tensorflow.python.client import session as _session
|
from tensorflow.python.client import session as _session
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function as _def_function
|
from tensorflow.python.eager import def_function as _def_function
|
||||||
@ -83,6 +82,7 @@ from tensorflow.python.saved_model import tag_constants as _tag_constants
|
|||||||
from tensorflow.python.saved_model.load import load as _load
|
from tensorflow.python.saved_model.load import load as _load
|
||||||
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
|
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
|
||||||
from tensorflow.python.util import deprecation as _deprecation
|
from tensorflow.python.util import deprecation as _deprecation
|
||||||
|
from tensorflow.python.util import keras_deps
|
||||||
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
||||||
|
|
||||||
|
|
||||||
@ -1466,10 +1466,8 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
|||||||
"with Eager mode. If your model requires any of these "
|
"with Eager mode. If your model requires any of these "
|
||||||
"parameters, please use disable_eager_execution().")
|
"parameters, please use disable_eager_execution().")
|
||||||
|
|
||||||
_keras.backend.set_learning_phase(False)
|
keras_model = keras_deps.get_load_model_function()(
|
||||||
keras_model = _keras.models.load_model(model_file, custom_objects)
|
model_file, custom_objects)
|
||||||
|
|
||||||
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
|
||||||
function = _keras_saving_utils.trace_model_call(keras_model)
|
function = _keras_saving_utils.trace_model_call(keras_model)
|
||||||
concrete_func = function.get_concrete_function()
|
concrete_func = function.get_concrete_function()
|
||||||
|
|
||||||
@ -1484,10 +1482,10 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Handles Keras when Eager mode is disabled.
|
# Handles Keras when Eager mode is disabled.
|
||||||
_keras.backend.clear_session()
|
keras_deps.get_clear_session_function()()
|
||||||
_keras.backend.set_learning_phase(False)
|
keras_model = keras_deps.get_load_model_function()(
|
||||||
keras_model = _keras.models.load_model(model_file, custom_objects)
|
model_file, custom_objects)
|
||||||
sess = _keras.backend.get_session()
|
sess = keras_deps.get_get_session_function()()
|
||||||
|
|
||||||
# Get input and output tensors.
|
# Get input and output tensors.
|
||||||
if input_arrays:
|
if input_arrays:
|
||||||
|
@ -31,10 +31,10 @@ from tensorflow.lite.python import lite
|
|||||||
from tensorflow.lite.python.convert import register_custom_opdefs
|
from tensorflow.lite.python.convert import register_custom_opdefs
|
||||||
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
|
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
|
||||||
from tensorflow.lite.toco.logging import gen_html
|
from tensorflow.lite.toco.logging import gen_html
|
||||||
from tensorflow.python import keras
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.platform import app
|
from tensorflow.python.platform import app
|
||||||
|
from tensorflow.python.util import keras_deps
|
||||||
|
|
||||||
|
|
||||||
def _parse_array(values, type_fn=str):
|
def _parse_array(values, type_fn=str):
|
||||||
@ -234,7 +234,7 @@ def _convert_tf2_model(flags):
|
|||||||
if flags.saved_model_dir:
|
if flags.saved_model_dir:
|
||||||
converter = lite.TFLiteConverterV2.from_saved_model(flags.saved_model_dir)
|
converter = lite.TFLiteConverterV2.from_saved_model(flags.saved_model_dir)
|
||||||
elif flags.keras_model_file:
|
elif flags.keras_model_file:
|
||||||
model = keras.models.load_model(flags.keras_model_file)
|
model = keras_deps.get_load_model_function()(flags.keras_model_file)
|
||||||
converter = lite.TFLiteConverterV2.from_keras_model(model)
|
converter = lite.TFLiteConverterV2.from_keras_model(model)
|
||||||
|
|
||||||
if flags.experimental_new_converter is not None:
|
if flags.experimental_new_converter is not None:
|
||||||
|
@ -21,11 +21,11 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tensorflow import keras
|
||||||
|
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.lite.python import tflite_convert
|
from tensorflow.lite.python import tflite_convert
|
||||||
from tensorflow.lite.python.convert import register_custom_opdefs
|
from tensorflow.lite.python.convert import register_custom_opdefs
|
||||||
from tensorflow.python import keras
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
|
@ -82,6 +82,7 @@ from tensorflow.python.platform import tf_logging as logging
|
|||||||
from tensorflow.python.training import moving_averages
|
from tensorflow.python.training import moving_averages
|
||||||
from tensorflow.python.training.tracking import util as tracking_util
|
from tensorflow.python.training.tracking import util as tracking_util
|
||||||
from tensorflow.python.util import dispatch
|
from tensorflow.python.util import dispatch
|
||||||
|
from tensorflow.python.util import keras_deps
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import object_identity
|
from tensorflow.python.util import object_identity
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
@ -318,6 +319,10 @@ def clear_session():
|
|||||||
_GRAPH_VARIABLES.pop(graph, None)
|
_GRAPH_VARIABLES.pop(graph, None)
|
||||||
_GRAPH_TF_OPTIMIZERS.pop(graph, None)
|
_GRAPH_TF_OPTIMIZERS.pop(graph, None)
|
||||||
|
|
||||||
|
# Inject the clear_session function to keras_deps to remove the dependency
|
||||||
|
# from TFLite to Keras.
|
||||||
|
keras_deps.register_clear_session_function(clear_session)
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.backend.manual_variable_initialization')
|
@keras_export('keras.backend.manual_variable_initialization')
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
@ -745,6 +750,9 @@ def get_session(op_input_list=()):
|
|||||||
_initialize_variables(session)
|
_initialize_variables(session)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
# Inject the get_session function to keras_deps to remove the dependency
|
||||||
|
# from TFLite to Keras.
|
||||||
|
keras_deps.register_get_session_function(get_session)
|
||||||
|
|
||||||
# Inject the get_session function to tracking_util to avoid the backward
|
# Inject the get_session function to tracking_util to avoid the backward
|
||||||
# dependency from TF to Keras.
|
# dependency from TF to Keras.
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.keras.utils import generic_utils
|
|||||||
from tensorflow.python.keras.utils.io_utils import path_to_string
|
from tensorflow.python.keras.utils.io_utils import path_to_string
|
||||||
from tensorflow.python.saved_model import load_context
|
from tensorflow.python.saved_model import load_context
|
||||||
from tensorflow.python.saved_model import loader_impl
|
from tensorflow.python.saved_model import loader_impl
|
||||||
|
from tensorflow.python.util import keras_deps
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
# pylint: disable=g-import-not-at-top
|
# pylint: disable=g-import-not-at-top
|
||||||
@ -214,3 +215,7 @@ def load_model(filepath, custom_objects=None, compile=True, options=None): # py
|
|||||||
raise IOError(
|
raise IOError(
|
||||||
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
|
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
|
||||||
'available) or SavedModel.')
|
'available) or SavedModel.')
|
||||||
|
|
||||||
|
# Inject the load_model function to keras_deps to remove the dependency
|
||||||
|
# from TFLite to Keras.
|
||||||
|
keras_deps.register_load_model_function(load_model)
|
||||||
|
@ -30,15 +30,52 @@ from __future__ import print_function
|
|||||||
|
|
||||||
|
|
||||||
_KERAS_CALL_CONTEXT_FUNCTION = None
|
_KERAS_CALL_CONTEXT_FUNCTION = None
|
||||||
|
_KERAS_CLEAR_SESSION_FUNCTION = None
|
||||||
|
_KERAS_GET_SESSION_FUNCTION = None
|
||||||
|
_KERAS_LOAD_MODEL_FUNCTION = None
|
||||||
|
|
||||||
|
|
||||||
def register_call_context_function(func):
|
|
||||||
global _KERAS_CALL_CONTEXT_FUNCTION
|
|
||||||
# TODO(scottzhu): Disable duplicated inject once keras is moved to
|
# TODO(scottzhu): Disable duplicated inject once keras is moved to
|
||||||
# third_party/py/keras.
|
# third_party/py/keras.
|
||||||
|
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
|
||||||
|
|
||||||
|
|
||||||
|
# Register functions
|
||||||
|
def register_call_context_function(func):
|
||||||
|
global _KERAS_CALL_CONTEXT_FUNCTION
|
||||||
_KERAS_CALL_CONTEXT_FUNCTION = func
|
_KERAS_CALL_CONTEXT_FUNCTION = func
|
||||||
|
|
||||||
|
|
||||||
|
def register_clear_session_function(func):
|
||||||
|
global _KERAS_CLEAR_SESSION_FUNCTION
|
||||||
|
_KERAS_CLEAR_SESSION_FUNCTION = func
|
||||||
|
|
||||||
|
|
||||||
|
def register_get_session_function(func):
|
||||||
|
global _KERAS_GET_SESSION_FUNCTION
|
||||||
|
_KERAS_GET_SESSION_FUNCTION = func
|
||||||
|
|
||||||
|
|
||||||
|
def register_load_model_function(func):
|
||||||
|
global _KERAS_LOAD_MODEL_FUNCTION
|
||||||
|
_KERAS_LOAD_MODEL_FUNCTION = func
|
||||||
|
|
||||||
|
|
||||||
|
# Get functions
|
||||||
def get_call_context_function():
|
def get_call_context_function():
|
||||||
global _KERAS_CALL_CONTEXT_FUNCTION
|
global _KERAS_CALL_CONTEXT_FUNCTION
|
||||||
return _KERAS_CALL_CONTEXT_FUNCTION
|
return _KERAS_CALL_CONTEXT_FUNCTION
|
||||||
|
|
||||||
|
|
||||||
|
def get_clear_session_function():
|
||||||
|
global _KERAS_CLEAR_SESSION_FUNCTION
|
||||||
|
return _KERAS_CLEAR_SESSION_FUNCTION
|
||||||
|
|
||||||
|
|
||||||
|
def get_get_session_function():
|
||||||
|
global _KERAS_GET_SESSION_FUNCTION
|
||||||
|
return _KERAS_GET_SESSION_FUNCTION
|
||||||
|
|
||||||
|
|
||||||
|
def get_load_model_function():
|
||||||
|
global _KERAS_LOAD_MODEL_FUNCTION
|
||||||
|
return _KERAS_LOAD_MODEL_FUNCTION
|
||||||
|
Loading…
Reference in New Issue
Block a user