Add support for Eager-mode in from_keras_model_file in 1.X.

PiperOrigin-RevId: 243897767
This commit is contained in:
Nupur Garg 2019-04-16 16:04:13 -07:00 committed by TensorFlower Gardener
parent ae943f1f3a
commit 7708e8c53a

View File

@ -50,6 +50,7 @@ from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as _def_function
from tensorflow.python.eager import function as _function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
@ -694,6 +695,26 @@ class TFLiteConverter(object):
Returns:
TFLiteConverter class.
"""
# Handles Keras when Eager mode is enabled.
if context.executing_eagerly():
if input_arrays or input_shapes or output_arrays:
raise ValueError("`input_arrays`, `input_shapes` and `output_arrays`"
"are unsupported with Eager mode. If your model "
"requires any of these parameters, please use "
"disable_eager_execution().")
_keras.backend.set_learning_phase(False)
keras_model = _keras.models.load_model(model_file, custom_objects)
function = _saving_utils.trace_model_call(keras_model)
concrete_func = function.get_concrete_function()
frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
concrete_func)
return cls(frozen_func.graph.as_graph_def(), frozen_func.inputs,
frozen_func.outputs)
# Handles Keras when Eager mode is disabled.
_keras.backend.clear_session()
_keras.backend.set_learning_phase(False)
keras_model = _keras.models.load_model(model_file, custom_objects)