Enable export of train/eval graphs for Keras models that use OptimizerV2.
Since OptimizerV2's get_config requires using the backend to get the values of hyperparameters, some code had to be moved into the Session that is created during export. PiperOrigin-RevId: 235039100
This commit is contained in:
parent
b78cce6820
commit
1d0ec3ec5d
@ -25,6 +25,7 @@ from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.keras.saving import model_from_json
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.utils import mode_keys
|
||||
@ -182,8 +183,8 @@ def _save_v1_format(model, path, custom_objects, as_text, input_signature):
|
||||
|
||||
has_saved_vars = False
|
||||
if model.optimizer:
|
||||
# TODO(kathywu): Verify this works with v2 optimizer.
|
||||
if isinstance(model.optimizer, optimizers.TFOptimizer):
|
||||
if isinstance(model.optimizer, (optimizers.TFOptimizer,
|
||||
optimizer_v2.OptimizerV2)):
|
||||
_export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args)
|
||||
has_saved_vars = True
|
||||
_export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args)
|
||||
@ -268,9 +269,8 @@ def _export_mode(
|
||||
clone._make_predict_function()
|
||||
g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
|
||||
|
||||
clone_var_list = _get_var_list(clone)
|
||||
|
||||
with session.Session().as_default():
|
||||
clone_var_list = _get_var_list(clone)
|
||||
if has_saved_vars:
|
||||
# Confirm all variables in the clone have an entry in the checkpoint.
|
||||
status = clone.load_weights(checkpoint_path)
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.optimizer_v2 import adadelta
|
||||
from tensorflow.python.keras.saving import saved_model as keras_saved_model
|
||||
from tensorflow.python.keras.utils import mode_keys
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
@ -286,42 +287,45 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
{
|
||||
'model_builder': functional_model,
|
||||
'uses_learning_phase': True,
|
||||
'optimizer': training_module.AdadeltaOptimizer(),
|
||||
'optimizer_cls': adadelta.Adadelta,
|
||||
'train_before_export': True},
|
||||
{
|
||||
'model_builder': functional_model,
|
||||
'uses_learning_phase': True,
|
||||
'optimizer': training_module.AdadeltaOptimizer(),
|
||||
'optimizer_cls': training_module.AdadeltaOptimizer,
|
||||
'train_before_export': False},
|
||||
{
|
||||
'model_builder': functional_model,
|
||||
'uses_learning_phase': False,
|
||||
'optimizer': None,
|
||||
'optimizer_cls': None,
|
||||
'train_before_export': False},
|
||||
{
|
||||
'model_builder': sequential_model,
|
||||
'uses_learning_phase': True,
|
||||
'optimizer': training_module.AdadeltaOptimizer(),
|
||||
'optimizer_cls': training_module.AdadeltaOptimizer,
|
||||
'train_before_export': True},
|
||||
{
|
||||
'model_builder': sequential_model,
|
||||
'uses_learning_phase': True,
|
||||
'optimizer': training_module.AdadeltaOptimizer(),
|
||||
'optimizer_cls': adadelta.Adadelta,
|
||||
'train_before_export': False},
|
||||
{
|
||||
'model_builder': sequential_model,
|
||||
'uses_learning_phase': False,
|
||||
'optimizer': None,
|
||||
'optimizer_cls': None,
|
||||
'train_before_export': False},
|
||||
{
|
||||
'model_builder': sequential_model_without_input_shape,
|
||||
'uses_learning_phase': True,
|
||||
'optimizer': training_module.AdadeltaOptimizer(),
|
||||
'optimizer_cls': training_module.AdadeltaOptimizer,
|
||||
'train_before_export': False})
|
||||
def testSaveAndLoadSavedModelExport(
|
||||
self, model_builder, uses_learning_phase, optimizer, train_before_export):
|
||||
self, model_builder, uses_learning_phase, optimizer_cls,
|
||||
train_before_export):
|
||||
optimizer = None if optimizer_cls is None else optimizer_cls()
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
with self.session(graph=ops.Graph()):
|
||||
|
||||
np.random.seed(130)
|
||||
input_arr = np.random.random((1, 3))
|
||||
target_arr = np.random.random((1, 3))
|
||||
|
Loading…
x
Reference in New Issue
Block a user