Enable export of train/eval graphs for Keras models that use OptimizerV2.

Since OptimizerV2's get_config requires using the backend to get the values of hyperparameters, some code had to be moved into the Session that is created during export.

PiperOrigin-RevId: 235039100
This commit is contained in:
Katherine Wu 2019-02-21 11:51:28 -08:00 committed by TensorFlower Gardener
parent b78cce6820
commit 1d0ec3ec5d
2 changed files with 39 additions and 35 deletions

View File

@@ -25,6 +25,7 @@ from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import model_from_json
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import mode_keys
@@ -182,8 +183,8 @@ def _save_v1_format(model, path, custom_objects, as_text, input_signature):
has_saved_vars = False
if model.optimizer:
# TODO(kathywu): Verify this works with v2 optimizer.
if isinstance(model.optimizer, optimizers.TFOptimizer):
if isinstance(model.optimizer, (optimizers.TFOptimizer,
optimizer_v2.OptimizerV2)):
_export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args)
has_saved_vars = True
_export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args)
@@ -268,9 +269,8 @@ def _export_mode(
clone._make_predict_function()
g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
clone_var_list = _get_var_list(clone)
with session.Session().as_default():
clone_var_list = _get_var_list(clone)
if has_saved_vars:
# Confirm all variables in the clone have an entry in the checkpoint.
status = clone.load_weights(checkpoint_path)

View File

@@ -33,6 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.optimizer_v2 import adadelta
from tensorflow.python.keras.saving import saved_model as keras_saved_model
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.keras.utils import tf_utils
@@ -286,42 +287,45 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
{
'model_builder': functional_model,
'uses_learning_phase': True,
'optimizer': training_module.AdadeltaOptimizer(),
'optimizer_cls': adadelta.Adadelta,
'train_before_export': True},
{
'model_builder': functional_model,
'uses_learning_phase': True,
'optimizer': training_module.AdadeltaOptimizer(),
'optimizer_cls': training_module.AdadeltaOptimizer,
'train_before_export': False},
{
'model_builder': functional_model,
'uses_learning_phase': False,
'optimizer': None,
'optimizer_cls': None,
'train_before_export': False},
{
'model_builder': sequential_model,
'uses_learning_phase': True,
'optimizer': training_module.AdadeltaOptimizer(),
'optimizer_cls': training_module.AdadeltaOptimizer,
'train_before_export': True},
{
'model_builder': sequential_model,
'uses_learning_phase': True,
'optimizer': training_module.AdadeltaOptimizer(),
'optimizer_cls': adadelta.Adadelta,
'train_before_export': False},
{
'model_builder': sequential_model,
'uses_learning_phase': False,
'optimizer': None,
'optimizer_cls': None,
'train_before_export': False},
{
'model_builder': sequential_model_without_input_shape,
'uses_learning_phase': True,
'optimizer': training_module.AdadeltaOptimizer(),
'optimizer_cls': training_module.AdadeltaOptimizer,
'train_before_export': False})
def testSaveAndLoadSavedModelExport(
self, model_builder, uses_learning_phase, optimizer, train_before_export):
self, model_builder, uses_learning_phase, optimizer_cls,
train_before_export):
optimizer = None if optimizer_cls is None else optimizer_cls()
saved_model_dir = self._save_model_dir()
with self.session(graph=ops.Graph()):
np.random.seed(130)
input_arr = np.random.random((1, 3))
target_arr = np.random.random((1, 3))