Remove keras completely from lite.py

PiperOrigin-RevId: 339385259
Change-Id: I8ffab8c686870a42c60565d863506a398c1e1238
This commit is contained in:
Meghna Natraj 2020-10-27 20:37:33 -07:00 committed by TensorFlower Gardener
parent 091f679cdf
commit b368310cbf
7 changed files with 64 additions and 14 deletions

View File

@ -86,6 +86,7 @@ py_library(
":lite",
"//tensorflow/lite/toco/logging:gen_html",
"//tensorflow/lite/toco/logging:toco_conversion_log_proto_py",
"//tensorflow/python:util",
"@six_archive//:six",
],
)
@ -111,6 +112,7 @@ py_test(
deps = [
":convert",
":tflite_convert",
"//tensorflow:tensorflow_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",

View File

@ -66,7 +66,6 @@ from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as _def_function
@ -83,6 +82,7 @@ from tensorflow.python.saved_model import tag_constants as _tag_constants
from tensorflow.python.saved_model.load import load as _load
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util import keras_deps
from tensorflow.python.util.tf_export import tf_export as _tf_export
@ -1466,10 +1466,8 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
"with Eager mode. If your model requires any of these "
"parameters, please use disable_eager_execution().")
_keras.backend.set_learning_phase(False)
keras_model = _keras.models.load_model(model_file, custom_objects)
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
keras_model = keras_deps.get_load_model_function()(
model_file, custom_objects)
function = _keras_saving_utils.trace_model_call(keras_model)
concrete_func = function.get_concrete_function()
@ -1484,10 +1482,10 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
return
# Handles Keras when Eager mode is disabled.
_keras.backend.clear_session()
_keras.backend.set_learning_phase(False)
keras_model = _keras.models.load_model(model_file, custom_objects)
sess = _keras.backend.get_session()
keras_deps.get_clear_session_function()()
keras_model = keras_deps.get_load_model_function()(
model_file, custom_objects)
sess = keras_deps.get_get_session_function()()
# Get input and output tensors.
if input_arrays:

View File

@ -31,10 +31,10 @@ from tensorflow.lite.python import lite
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco.logging import gen_html
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import app
from tensorflow.python.util import keras_deps
def _parse_array(values, type_fn=str):
@ -234,7 +234,7 @@ def _convert_tf2_model(flags):
if flags.saved_model_dir:
converter = lite.TFLiteConverterV2.from_saved_model(flags.saved_model_dir)
elif flags.keras_model_file:
model = keras.models.load_model(flags.keras_model_file)
model = keras_deps.get_load_model_function()(flags.keras_model_file)
converter = lite.TFLiteConverterV2.from_keras_model(model)
if flags.experimental_new_converter is not None:

View File

@ -21,11 +21,11 @@ from __future__ import print_function
import os
import numpy as np
from tensorflow import keras
from tensorflow.core.framework import graph_pb2
from tensorflow.lite.python import tflite_convert
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import def_function

View File

@ -82,6 +82,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import moving_averages
from tensorflow.python.training.tracking import util as tracking_util
from tensorflow.python.util import dispatch
from tensorflow.python.util import keras_deps
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import keras_export
@ -318,6 +319,10 @@ def clear_session():
_GRAPH_VARIABLES.pop(graph, None)
_GRAPH_TF_OPTIMIZERS.pop(graph, None)
# Inject the clear_session function to keras_deps to remove the dependency
# from TFLite to Keras.
keras_deps.register_clear_session_function(clear_session)
@keras_export('keras.backend.manual_variable_initialization')
@doc_controls.do_not_generate_docs
@ -745,6 +750,9 @@ def get_session(op_input_list=()):
_initialize_variables(session)
return session
# Inject the get_session function to keras_deps to remove the dependency
# from TFLite to Keras.
keras_deps.register_get_session_function(get_session)
# Inject the get_session function to tracking_util to avoid the backward
# dependency from TF to Keras.

View File

@ -29,6 +29,7 @@ from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.saved_model import load_context
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.util import keras_deps
from tensorflow.python.util.tf_export import keras_export
# pylint: disable=g-import-not-at-top
@ -214,3 +215,7 @@ def load_model(filepath, custom_objects=None, compile=True, options=None): # py
raise IOError(
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
'available) or SavedModel.')
# Inject the load_model function to keras_deps to remove the dependency
# from TFLite to Keras.
keras_deps.register_load_model_function(load_model)

View File

@ -30,15 +30,52 @@ from __future__ import print_function
_KERAS_CALL_CONTEXT_FUNCTION = None
_KERAS_CLEAR_SESSION_FUNCTION = None
_KERAS_GET_SESSION_FUNCTION = None
_KERAS_LOAD_MODEL_FUNCTION = None
# TODO(scottzhu): Disable duplicated inject once keras is moved to
# third_party/py/keras.
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
# Register functions
def register_call_context_function(func):
global _KERAS_CALL_CONTEXT_FUNCTION
# TODO(scottzhu): Disable duplicated inject once keras is moved to
# third_party/py/keras.
_KERAS_CALL_CONTEXT_FUNCTION = func
def register_clear_session_function(func):
global _KERAS_CLEAR_SESSION_FUNCTION
_KERAS_CLEAR_SESSION_FUNCTION = func
def register_get_session_function(func):
global _KERAS_GET_SESSION_FUNCTION
_KERAS_GET_SESSION_FUNCTION = func
def register_load_model_function(func):
global _KERAS_LOAD_MODEL_FUNCTION
_KERAS_LOAD_MODEL_FUNCTION = func
# Get functions
def get_call_context_function():
global _KERAS_CALL_CONTEXT_FUNCTION
return _KERAS_CALL_CONTEXT_FUNCTION
def get_clear_session_function():
global _KERAS_CLEAR_SESSION_FUNCTION
return _KERAS_CLEAR_SESSION_FUNCTION
def get_get_session_function():
global _KERAS_GET_SESSION_FUNCTION
return _KERAS_GET_SESSION_FUNCTION
def get_load_model_function():
global _KERAS_LOAD_MODEL_FUNCTION
return _KERAS_LOAD_MODEL_FUNCTION