Remove keras completely from lite.py

PiperOrigin-RevId: 339385259
Change-Id: I8ffab8c686870a42c60565d863506a398c1e1238
This commit is contained in:
Meghna Natraj 2020-10-27 20:37:33 -07:00 committed by TensorFlower Gardener
parent 091f679cdf
commit b368310cbf
7 changed files with 64 additions and 14 deletions

View File

@ -86,6 +86,7 @@ py_library(
":lite", ":lite",
"//tensorflow/lite/toco/logging:gen_html", "//tensorflow/lite/toco/logging:gen_html",
"//tensorflow/lite/toco/logging:toco_conversion_log_proto_py", "//tensorflow/lite/toco/logging:toco_conversion_log_proto_py",
"//tensorflow/python:util",
"@six_archive//:six", "@six_archive//:six",
], ],
) )
@ -111,6 +112,7 @@ py_test(
deps = [ deps = [
":convert", ":convert",
":tflite_convert", ":tflite_convert",
"//tensorflow:tensorflow_py",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op", "//tensorflow/python:constant_op",

View File

@ -66,7 +66,6 @@ from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session from tensorflow.python.client import session as _session
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as _def_function from tensorflow.python.eager import def_function as _def_function
@ -83,6 +82,7 @@ from tensorflow.python.saved_model import tag_constants as _tag_constants
from tensorflow.python.saved_model.load import load as _load from tensorflow.python.saved_model.load import load as _load
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
from tensorflow.python.util import deprecation as _deprecation from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util import keras_deps
from tensorflow.python.util.tf_export import tf_export as _tf_export from tensorflow.python.util.tf_export import tf_export as _tf_export
@ -1466,10 +1466,8 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
"with Eager mode. If your model requires any of these " "with Eager mode. If your model requires any of these "
"parameters, please use disable_eager_execution().") "parameters, please use disable_eager_execution().")
_keras.backend.set_learning_phase(False) keras_model = keras_deps.get_load_model_function()(
keras_model = _keras.models.load_model(model_file, custom_objects) model_file, custom_objects)
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
function = _keras_saving_utils.trace_model_call(keras_model) function = _keras_saving_utils.trace_model_call(keras_model)
concrete_func = function.get_concrete_function() concrete_func = function.get_concrete_function()
@ -1484,10 +1482,10 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
return return
# Handles Keras when Eager mode is disabled. # Handles Keras when Eager mode is disabled.
_keras.backend.clear_session() keras_deps.get_clear_session_function()()
_keras.backend.set_learning_phase(False) keras_model = keras_deps.get_load_model_function()(
keras_model = _keras.models.load_model(model_file, custom_objects) model_file, custom_objects)
sess = _keras.backend.get_session() sess = keras_deps.get_get_session_function()()
# Get input and output tensors. # Get input and output tensors.
if input_arrays: if input_arrays:

View File

@ -31,10 +31,10 @@ from tensorflow.lite.python import lite
from tensorflow.lite.python.convert import register_custom_opdefs from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco.logging import gen_html from tensorflow.lite.toco.logging import gen_html
from tensorflow.python import keras
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.platform import app from tensorflow.python.platform import app
from tensorflow.python.util import keras_deps
def _parse_array(values, type_fn=str): def _parse_array(values, type_fn=str):
@ -234,7 +234,7 @@ def _convert_tf2_model(flags):
if flags.saved_model_dir: if flags.saved_model_dir:
converter = lite.TFLiteConverterV2.from_saved_model(flags.saved_model_dir) converter = lite.TFLiteConverterV2.from_saved_model(flags.saved_model_dir)
elif flags.keras_model_file: elif flags.keras_model_file:
model = keras.models.load_model(flags.keras_model_file) model = keras_deps.get_load_model_function()(flags.keras_model_file)
converter = lite.TFLiteConverterV2.from_keras_model(model) converter = lite.TFLiteConverterV2.from_keras_model(model)
if flags.experimental_new_converter is not None: if flags.experimental_new_converter is not None:

View File

@ -21,11 +21,11 @@ from __future__ import print_function
import os import os
import numpy as np import numpy as np
from tensorflow import keras
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.lite.python import tflite_convert from tensorflow.lite.python import tflite_convert
from tensorflow.lite.python.convert import register_custom_opdefs from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.python import keras
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function

View File

@ -82,6 +82,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import moving_averages from tensorflow.python.training import moving_averages
from tensorflow.python.training.tracking import util as tracking_util from tensorflow.python.training.tracking import util as tracking_util
from tensorflow.python.util import dispatch from tensorflow.python.util import dispatch
from tensorflow.python.util import keras_deps
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import object_identity from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -318,6 +319,10 @@ def clear_session():
_GRAPH_VARIABLES.pop(graph, None) _GRAPH_VARIABLES.pop(graph, None)
_GRAPH_TF_OPTIMIZERS.pop(graph, None) _GRAPH_TF_OPTIMIZERS.pop(graph, None)
# Inject the clear_session function to keras_deps to remove the dependency
# from TFLite to Keras.
keras_deps.register_clear_session_function(clear_session)
@keras_export('keras.backend.manual_variable_initialization') @keras_export('keras.backend.manual_variable_initialization')
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
@ -745,6 +750,9 @@ def get_session(op_input_list=()):
_initialize_variables(session) _initialize_variables(session)
return session return session
# Inject the get_session function to keras_deps to remove the dependency
# from TFLite to Keras.
keras_deps.register_get_session_function(get_session)
# Inject the get_session function to tracking_util to avoid the backward # Inject the get_session function to tracking_util to avoid the backward
# dependency from TF to Keras. # dependency from TF to Keras.

View File

@ -29,6 +29,7 @@ from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.io_utils import path_to_string from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.saved_model import load_context from tensorflow.python.saved_model import load_context
from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import loader_impl
from tensorflow.python.util import keras_deps
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
@ -214,3 +215,7 @@ def load_model(filepath, custom_objects=None, compile=True, options=None): # py
raise IOError( raise IOError(
'Unable to load model. Filepath is not an hdf5 file (or h5py is not ' 'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
'available) or SavedModel.') 'available) or SavedModel.')
# Inject the load_model function to keras_deps to remove the dependency
# from TFLite to Keras.
keras_deps.register_load_model_function(load_model)

View File

@ -30,15 +30,52 @@ from __future__ import print_function
_KERAS_CALL_CONTEXT_FUNCTION = None _KERAS_CALL_CONTEXT_FUNCTION = None
_KERAS_CLEAR_SESSION_FUNCTION = None
_KERAS_GET_SESSION_FUNCTION = None
_KERAS_LOAD_MODEL_FUNCTION = None
def register_call_context_function(func):
global _KERAS_CALL_CONTEXT_FUNCTION
# TODO(scottzhu): Disable duplicated inject once keras is moved to # TODO(scottzhu): Disable duplicated inject once keras is moved to
# third_party/py/keras. # third_party/py/keras.
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
# Register functions
def register_call_context_function(func):
global _KERAS_CALL_CONTEXT_FUNCTION
_KERAS_CALL_CONTEXT_FUNCTION = func _KERAS_CALL_CONTEXT_FUNCTION = func
def register_clear_session_function(func):
global _KERAS_CLEAR_SESSION_FUNCTION
_KERAS_CLEAR_SESSION_FUNCTION = func
def register_get_session_function(func):
global _KERAS_GET_SESSION_FUNCTION
_KERAS_GET_SESSION_FUNCTION = func
def register_load_model_function(func):
global _KERAS_LOAD_MODEL_FUNCTION
_KERAS_LOAD_MODEL_FUNCTION = func
# Get functions
def get_call_context_function(): def get_call_context_function():
global _KERAS_CALL_CONTEXT_FUNCTION global _KERAS_CALL_CONTEXT_FUNCTION
return _KERAS_CALL_CONTEXT_FUNCTION return _KERAS_CALL_CONTEXT_FUNCTION
def get_clear_session_function():
global _KERAS_CLEAR_SESSION_FUNCTION
return _KERAS_CLEAR_SESSION_FUNCTION
def get_get_session_function():
global _KERAS_GET_SESSION_FUNCTION
return _KERAS_GET_SESSION_FUNCTION
def get_load_model_function():
global _KERAS_LOAD_MODEL_FUNCTION
return _KERAS_LOAD_MODEL_FUNCTION