Enable export of train/eval graphs for Keras models that use OptimizerV2.
Since OptimizerV2's get_config requires using the backend to get the values of hyperparameters, some code had to be moved into the Session that is created during export. PiperOrigin-RevId: 235039100
This commit is contained in:
parent
b78cce6820
commit
1d0ec3ec5d
@ -25,6 +25,7 @@ from tensorflow.python.client import session
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras import optimizers
|
from tensorflow.python.keras import optimizers
|
||||||
|
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||||
from tensorflow.python.keras.saving import model_from_json
|
from tensorflow.python.keras.saving import model_from_json
|
||||||
from tensorflow.python.keras.saving import saving_utils
|
from tensorflow.python.keras.saving import saving_utils
|
||||||
from tensorflow.python.keras.utils import mode_keys
|
from tensorflow.python.keras.utils import mode_keys
|
||||||
@ -182,8 +183,8 @@ def _save_v1_format(model, path, custom_objects, as_text, input_signature):
|
|||||||
|
|
||||||
has_saved_vars = False
|
has_saved_vars = False
|
||||||
if model.optimizer:
|
if model.optimizer:
|
||||||
# TODO(kathywu): Verify this works with v2 optimizer.
|
if isinstance(model.optimizer, (optimizers.TFOptimizer,
|
||||||
if isinstance(model.optimizer, optimizers.TFOptimizer):
|
optimizer_v2.OptimizerV2)):
|
||||||
_export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args)
|
_export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args)
|
||||||
has_saved_vars = True
|
has_saved_vars = True
|
||||||
_export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args)
|
_export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args)
|
||||||
@ -268,9 +269,8 @@ def _export_mode(
|
|||||||
clone._make_predict_function()
|
clone._make_predict_function()
|
||||||
g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
|
g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
|
||||||
|
|
||||||
clone_var_list = _get_var_list(clone)
|
|
||||||
|
|
||||||
with session.Session().as_default():
|
with session.Session().as_default():
|
||||||
|
clone_var_list = _get_var_list(clone)
|
||||||
if has_saved_vars:
|
if has_saved_vars:
|
||||||
# Confirm all variables in the clone have an entry in the checkpoint.
|
# Confirm all variables in the clone have an entry in the checkpoint.
|
||||||
status = clone.load_weights(checkpoint_path)
|
status = clone.load_weights(checkpoint_path)
|
||||||
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras.engine import training
|
from tensorflow.python.keras.engine import training
|
||||||
|
from tensorflow.python.keras.optimizer_v2 import adadelta
|
||||||
from tensorflow.python.keras.saving import saved_model as keras_saved_model
|
from tensorflow.python.keras.saving import saved_model as keras_saved_model
|
||||||
from tensorflow.python.keras.utils import mode_keys
|
from tensorflow.python.keras.utils import mode_keys
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
@ -286,42 +287,45 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
|||||||
{
|
{
|
||||||
'model_builder': functional_model,
|
'model_builder': functional_model,
|
||||||
'uses_learning_phase': True,
|
'uses_learning_phase': True,
|
||||||
'optimizer': training_module.AdadeltaOptimizer(),
|
'optimizer_cls': adadelta.Adadelta,
|
||||||
'train_before_export': True},
|
'train_before_export': True},
|
||||||
{
|
{
|
||||||
'model_builder': functional_model,
|
'model_builder': functional_model,
|
||||||
'uses_learning_phase': True,
|
'uses_learning_phase': True,
|
||||||
'optimizer': training_module.AdadeltaOptimizer(),
|
'optimizer_cls': training_module.AdadeltaOptimizer,
|
||||||
'train_before_export': False},
|
'train_before_export': False},
|
||||||
{
|
{
|
||||||
'model_builder': functional_model,
|
'model_builder': functional_model,
|
||||||
'uses_learning_phase': False,
|
'uses_learning_phase': False,
|
||||||
'optimizer': None,
|
'optimizer_cls': None,
|
||||||
'train_before_export': False},
|
'train_before_export': False},
|
||||||
{
|
{
|
||||||
'model_builder': sequential_model,
|
'model_builder': sequential_model,
|
||||||
'uses_learning_phase': True,
|
'uses_learning_phase': True,
|
||||||
'optimizer': training_module.AdadeltaOptimizer(),
|
'optimizer_cls': training_module.AdadeltaOptimizer,
|
||||||
'train_before_export': True},
|
'train_before_export': True},
|
||||||
{
|
{
|
||||||
'model_builder': sequential_model,
|
'model_builder': sequential_model,
|
||||||
'uses_learning_phase': True,
|
'uses_learning_phase': True,
|
||||||
'optimizer': training_module.AdadeltaOptimizer(),
|
'optimizer_cls': adadelta.Adadelta,
|
||||||
'train_before_export': False},
|
'train_before_export': False},
|
||||||
{
|
{
|
||||||
'model_builder': sequential_model,
|
'model_builder': sequential_model,
|
||||||
'uses_learning_phase': False,
|
'uses_learning_phase': False,
|
||||||
'optimizer': None,
|
'optimizer_cls': None,
|
||||||
'train_before_export': False},
|
'train_before_export': False},
|
||||||
{
|
{
|
||||||
'model_builder': sequential_model_without_input_shape,
|
'model_builder': sequential_model_without_input_shape,
|
||||||
'uses_learning_phase': True,
|
'uses_learning_phase': True,
|
||||||
'optimizer': training_module.AdadeltaOptimizer(),
|
'optimizer_cls': training_module.AdadeltaOptimizer,
|
||||||
'train_before_export': False})
|
'train_before_export': False})
|
||||||
def testSaveAndLoadSavedModelExport(
|
def testSaveAndLoadSavedModelExport(
|
||||||
self, model_builder, uses_learning_phase, optimizer, train_before_export):
|
self, model_builder, uses_learning_phase, optimizer_cls,
|
||||||
|
train_before_export):
|
||||||
|
optimizer = None if optimizer_cls is None else optimizer_cls()
|
||||||
|
|
||||||
saved_model_dir = self._save_model_dir()
|
saved_model_dir = self._save_model_dir()
|
||||||
with self.session(graph=ops.Graph()):
|
|
||||||
np.random.seed(130)
|
np.random.seed(130)
|
||||||
input_arr = np.random.random((1, 3))
|
input_arr = np.random.random((1, 3))
|
||||||
target_arr = np.random.random((1, 3))
|
target_arr = np.random.random((1, 3))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user